{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.19847283, -0.9801064 ,  0.23964715], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"Pendulum-v1 wrapper exposing a simplified reset/step API.\n",
    "\n",
    "    ``reset`` returns only the observation and ``step`` returns the 4-tuple\n",
    "    ``(state, reward, done, info)``, where ``done`` merges the wrapped\n",
    "    environment's ``terminated`` and ``truncated`` flags with a step cap.\n",
    "\n",
    "    Args:\n",
    "        max_steps: Number of steps after which ``step`` reports\n",
    "            ``done=True`` regardless of the wrapped environment's flags.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        # gym.Wrapper.__init__ stores the environment as self.env, so no\n",
    "        # separate assignment is needed here.\n",
    "        super().__init__(env)\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self, seed=None):\n",
    "        \"\"\"Start a new episode and return only the initial observation.\n",
    "\n",
    "        Args:\n",
    "            seed: Optional seed forwarded to the wrapped environment's\n",
    "                ``reset`` so that an episode can be reproduced.\n",
    "\n",
    "        Returns:\n",
    "            The initial observation; the info dict is discarded.\n",
    "        \"\"\"\n",
    "        state, _ = self.env.reset(seed=seed)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Advance one step and return ``(state, reward, done, info)``.\"\"\"\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        self.step_n += 1\n",
    "        # The episode ends on either flag from the wrapped environment or\n",
    "        # once the step cap has been reached.\n",
    "        done = terminated or truncated or self.step_n >= self.max_steps\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "# Build the wrapped environment once; show() in a later cell reads this global.\n",
    "env = MyWrapper()\n",
    "\n",
    "# Reset so the initial observation is shown as the cell output.\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAitElEQVR4nO3df3DU9YH/8ddu9kd+sRsSyC6RpKRC5XIItYCw1z86d+RIbabVys1Zh/M4ZerIBQek44zcKU777UwY+/1eW++U3kyn6h9VbnIt9uSgNhMwXM/wK0KLqFGnYCKwGyFmNwnJJtl9f//I5XOuoE0gyb4Tn4+ZnTGfz3uz78+7zD772f3sxmWMMQIAwELubE8AAIBPQqQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANbKWqSeeuopLViwQLm5uVq1apWOHj2arakAACyVlUj927/9m7Zt26bHH39cr732mpYtW6aamhp1dnZmYzoAAEu5svEFs6tWrdLKlSv1L//yL5KkdDqt8vJyPfjgg3rkkUemejoAAEt5pvoBBwcH1draqu3btzvb3G63qqur1dLSctX7JJNJJZNJ5+d0Oq2uri6VlJTI5XJN+pwBABPLGKOenh6VlZXJ7f7kF/WmPFIXL15UKpVSKBTK2B4KhfTWW29d9T719fX67ne/OxXTAwBMoY6ODs2fP/8T9095pK7F9u3btW3bNufneDyuiooKdXR0KBAIZHFmAIBrkUgkVF5erlmzZn3quCmP1Jw5c5STk6NYLJaxPRaLKRwOX/U+fr9ffr//iu2BQIBIAcA09sfespnyq/t8Pp+WL1+upqYmZ1s6nVZTU5MikchUTwcAYLGsvNy3bds2bdiwQStWrNCtt96qH/3oR+rr69O9996bjekAACyVlUjddddd+uCDD7Rjxw5Fo1F98Ytf1K9//esrLqYAAHy2ZeVzUtcrkUgoGAwqHo/znhQATENjfR7nu/sAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWGvckTp06JC+/vWvq6ysTC6XSy+++GLGfmOMduzYoXnz5ikvL0/V1dV65513MsZ0dXVp/fr1CgQCKioq0saNG9Xb23tdBwIAmHnGHam+vj4tW7ZMTz311FX3P/HEE3ryySf1k5/8REeOHFFBQYFqamo0MDDgjFm/fr1Onz6txsZG7d27V4cOHdL9999/7UcBAJiZzHWQZPbs2eP8nE6nTTgcNj/4wQ+cbd3d3cbv95sXXnjBGGPMG2+8YSSZY8eOOWP2799vXC6XOXfu3JgeNx6PG0kmHo9fz/QBAFky1ufxCX1P6syZM4pGo6qurna2BYNBrVq1Si0tLZKklpYWFRUVacWKFc6Y6upqud1uHTly5Kq/N5lMKpFIZNwAADPfhEYqGo1KkkKhUMb2UCjk7ItGoyotLc3Y7/F4VFxc7Iz5uPr6egWDQedWXl4+kdMGAFhqWlzdt337dsXjcefW0dGR7SkBAKbAhEYqHA5LkmKxWMb2WCzm7AuHw+rs7MzYPzw8rK6uLmfMx/n9fgUCgYwbAGDmm9BIVVZWKhwOq6mpydmWSCR05MgR
RSIRSVIkElF3d7daW1udMQcOHFA6ndaqVasmcjoAgGnOM9479Pb26t1333V+PnPmjE6ePKni4mJVVFRo69at+v73v69FixapsrJSjz32mMrKynTHHXdIkv7kT/5EX/3qV/Xtb39bP/nJTzQ0NKTNmzfrW9/6lsrKyibswAAAM8B4Lxs8ePCgkXTFbcOGDcaYkcvQH3vsMRMKhYzf7zdr1qwxbW1tGb/j0qVL5u677zaFhYUmEAiYe++91/T09Ez4pYsAADuN9XncZYwxWWzkNUkkEgoGg4rH47w/BQDT0Fifx6fF1X0AgM8mIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLXGFan6+nqtXLlSs2bNUmlpqe644w61tbVljBkYGFBdXZ1KSkpUWFiodevWKRaLZYxpb29XbW2t8vPzVVpaqocffljDw8PXfzQAgBllXJFqbm5WXV2dDh8+rMbGRg0NDWnt2rXq6+tzxjz00EN66aWX1NDQoObmZp0/f1533nmnsz+VSqm2tlaDg4N69dVX9dxzz+nZZ5/Vjh07Ju6oAAAzg7kOnZ2dRpJpbm42xhjT3d1tvF6vaWhocMa8+eabRpJpaWkxxhizb98+43a7TTQadcbs2rXLBAIBk0wmx/S48XjcSDLxePx6pg8AyJKxPo9f13tS8XhcklRcXCxJam1t1dDQkKqrq50xixcvVkVFhVpaWiRJLS0tuvnmmxUKhZwxNTU1SiQSOn369FUfJ5lMKpFIZNwAADPfNUcqnU5r69at+vKXv6wlS5ZIkqLRqHw+n4qKijLGhkIhRaNRZ8xHAzW6f3Tf1dTX1ysYDDq38vLya502AGAaueZI1dXV6fXXX9fu3bsncj5XtX37dsXjcefW0dEx6Y8JAMg+z7XcafPmzdq7d68OHTqk+fPnO9vD4bAGBwfV3d2dcTYVi8UUDoedMUePHs34faNX/42O+Ti/3y+/338tUwUATGPjOpMyxmjz5s3as2ePDhw4oMrKyoz9y5cvl9frVVNTk7Otra1N7e3tikQikqRIJKJTp06ps7PTGdPY2KhAIKCqqqrrORYAwAwzrjOpuro6Pf/88/rVr36lWbNmOe8hBYNB5eXlKRgMauPGjdq2bZuKi4sVCAT04IMPKhKJaPXq1ZKktWvXqqqqSvfcc4+eeOIJRaNRPfroo6qrq+NsCQCQwWWMMWMe7HJddfszzzyjv/u7v5M08mHe73znO3rhhReUTCZVU1Ojp59+OuOlvPfee0+bNm3SK6+8ooKCAm3YsEE7d+6UxzO2ZiYSCQWDQcXjcQUCgbFOHwBgibE+j48rUrYgUgAwvY31eZzv7gMAWItIAQCsRaQAANa6ps9JAYBtjDFSOq304KBSly8rnUxKo2+5u91y+3zyFBbKzVXE0wqRAjBtmVRKg11dGrp0SYMXL2owGtXAhQu6/O67Gnj/fZl0Wkqn5c7Lkz8U0qxly1S0erUKFi6U2+fL9vQxBkQKgLUyLj42ZiRKH3ygy3/4g/rPntVAR4eGuro0nEhoOJFQ6iN/Nuij0pcvq//MGfWfPavuw4c1Z+1ahW6/XW6vd4qOBNeKSAGwSnp4WOn+fqUGBpTq69NgLKbLZ86o75131H/2rFI9
PTKplEwqJaXT4/vlxmgwGtWF559XqrdX8/76r5WTnz85B4IJQaQAZFV6aEhDly4pGYtp8IMPlIzFlDx/XsloVMkLF5Tq7Z3wxzTDw/pg/355i4s197bbOKOyGJECMKmcl+xGL2wYHlbywoWRs6O339ZAe7tSly87L9el+/unZF7p/n5Ff/ELFS5erPxFiz7xG3WQXUQKwIRLDw0p1denVF+fhnt6nCj1nzmjvrfekkmnnYsalMUvvRn+8EN1HTqk/BtvlHJysjYPfDIiBeC6pQcHR16qu3BBA+fPa7Czc+Tlu1hMA+fPywwOZnuKn+iDfft0wz33yEWkrESkAPxRxpiRq+v+5+zHDA+r/+xZXT57VpfffXfkgob+fufsyQwNZXvKmCGIFIArGGNkBgc13Ns7cnn36Et2776r/vfeU//Zs0oPDo68VDf9vqMa0wiRAiBjjNLJpJIXLigZjWrg3DkNRqMavHhRyVhMQxcvjnyDAzDFiBTwGTH6tUEmlZIZHlZqYEADHR0jnz86c0b97e1KX76s1MCA0v39MsPD2Z7ylOBrkuxGpIAZyhij9MCAhhMJDXV3a+jDD5W8cGHkvaQzZ5S8cEGGsyPNv+8+uficlLWIFDCDpPr7NfD++xo4d06X//AHDY1+r92lSxrq6rL6KrtsKPjCF1T4p38ql5s/CGErIgVMI8YYmeHhkdvQkIZ7e52X7HrfekuDsZjSyeTIbWCAixo+ha+0VOFvfUv+cDjbU8GnIFKA5VKXL4+cEX344ciFDOfPa+D990e+YPXcuWxPb1ryhUKaf999Ci5fzjdNWI5IARYxxih1+bIGOjrU/957GmhvV7Kzc+Sbvru7NXjp0vi/VBUOV06OZi1bptJvfEOBL36RQE0DRAqwgEmlNNzTo65DhxT9xS9khoZkhob+97NIGD+XSy6fTzl+v1w+nwoXL1bJmjUqWLxYOXl5vA81TRApIMvSyaS6Dx9W7KWXdPmdd4jSdcgJBOQrLpZv7lz5QiHl3nCD8ioqlPu5z8kbCGR7ergGRArIImOMPvjNbxRtaNBwd3e2pzP9uN0qWLxY+ZWVyluwQP5wWJ5gUN6iInkCAc6WZgAiBWSJSaV06cABnf/5z5W+fDnb07GX262cvDy58/KUk5srf3m5ZlVVKX/RIuWGw3Ln5cmVkyOXxzPyEh/vM80oRArIkr6331a0oeGKQJ3r69OJri71DA1pbm6uInPnquCz9GFTl2vk5bo5c+QLheQPhZQ7f75yy8vlnzdPObm52Z4hphCRArIgPTSk+PHjSkajzjZjjM709urxEyd0trdXA6mUAl6vlsyerf+7cqW8M/GlK5dLcrnkLSlRwcKFyqusVH5lpbzFxfLMmiVPMKicvLxszxJZRKSALBi6dEmxX/4yY9sfent1/3//t+If+TMX8aEh/Xdnp7YcOaL/c8stKpnGZxEur1eewkLlFBbKU1Sk/AULlH/TTSq48UZ5i4vl8nhG3kP6nxjzsh0kIgVkhTFGJpXK2Paj06czAvVRRy9eVOP58/rW5z8/FdObGDk5yr3hBvlDIfnDYflKS5VbXq7cG26Qt6REbg9PP/jj+FcC4Pq43SNnQC6XcisqVLBokfI//3nllpfLEwzKU1CgnMJCuT9L76thwhApAGPncimnoGDkZbtAQL65c5V/440qWLRIeQsWKKegYORluv+58ZIdrheRAixRW16u4xcvaugqH+ZdUFiopcXFWZiV5AkE5C8rk3/evJGX78Lhkavu5s2TZ9YsQoRJRaSALPAGgyr+8z9X18GDzraasjJJ0vd/9zsNplJKS8pxuVTk8+n/rVypzxUWTt6EXK6Rzxrl5Eg5OSpcvFgFX/iC8m+8Uf5wWDn5+XLn5/N1QphyRArIAndenmZHIoofP65UT4+kkavZasrKND8/X3vff1+XBga0oLBQd1VWqmSC/3qsKydn5P2iYHDkTGnevP+N0rx5cvt8Iy/ZjY7nbAlZQqSA
LHC5XArccovm3nabYr/4hXOln8vl0pLZs7Vk9uwJf0xfOKzcG24Y+WDsDTeMfGC2tFS+uXP5gCysRaSALHH7/Qp985sa+vBDdR08KDM8fN2/05WTI5fXK5fHI29x8cjZ0cKFKli4UJ6iIrn9fuXk5srl83F2hGmBSAFZlJOfr/n33SdvcbG6Dh7UYGfnuO7v8vnkKymRt7hY3jlzlDt/vvI//3nlL1ggb3Gx88FYZzxhwjRDpIAscrlc8hQUaN5f/ZUCS5fqw1dfVe/p00pGo0onkyPvC330jxy63cpbsEC55eXKq6iQPxRyvufOM3s2H5DFjMO/aMACbr9fhUuWqOALX1Cqv19meFgmndbFl19Wb1ubChcvVv6NN458Fik/X26vV26/f+RqPGAGI1KAJVwul1x+v9wfuZKv7G/+5ooxwGcJkQIsRpTwWcen8gAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYa1yR2rVrl5YuXapAIKBAIKBIJKL9+/c7+wcGBlRXV6eSkhIVFhZq3bp1isViGb+jvb1dtbW1ys/PV2lpqR5++GENDw9PzNEAAGaUcUVq/vz52rlzp1pbW3X8+HH9xV/8hW6//XadPn1akvTQQw/ppZdeUkNDg5qbm3X+/Hndeeedzv1TqZRqa2s1ODioV199Vc8995yeffZZ7dixY2KPCgAwM5jrNHv2bPPTn/7UdHd3G6/XaxoaGpx9b775ppFkWlpajDHG7Nu3z7jdbhONRp0xu3btMoFAwCSTyTE/ZjweN5JMPB6/3ukDALJgrM/j1/yeVCqV0u7du9XX16dIJKLW1lYNDQ2purraGbN48WJVVFSopaVFktTS0qKbb75ZoVDIGVNTU6NEIuGcjV1NMplUIpHIuAEAZr5xR+rUqVMqLCyU3+/XAw88oD179qiqqkrRaFQ+n09FRUUZ40OhkKLRqCQpGo1mBGp0/+i+T1JfX69gMOjcysvLxzttAMA0NO5I3XTTTTp58qSOHDmiTZs2acOGDXrjjTcmY26O7du3Kx6PO7eOjo5JfTwAgB08472Dz+fTwoULJUnLly/XsWPH9OMf/1h33XWXBgcH1d3dnXE2FYvFFA6HJUnhcFhHjx7N+H2jV/+Njrkav98vv98/3qkCAKa56/6cVDqdVjKZ1PLly+X1etXU1OTsa2trU3t7uyKRiCQpEono1KlT6uzsdMY0NjYqEAioqqrqeqcCAJhhxnUmtX37dt12222qqKhQT0+Pnn/+eb3yyit6+eWXFQwGtXHjRm3btk3FxcUKBAJ68MEHFYlEtHr1aknS2rVrVVVVpXvuuUdPPPGEotGoHn30UdXV1XGmBAC4wrgi1dnZqb/927/VhQsXFAwGtXTpUr388sv6y7/8S0nSD3/4Q7ndbq1bt07JZFI1NTV6+umnnfvn5ORo79692rRpkyKRiAoKCrRhwwZ973vfm9ijAgDMCC5jjMn2JMYrkUgoGAwqHo8rEAhkezoAgHEa6/M4390HALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItI
AQCsRaQAANYiUgAAaxEpAIC1iBQAwFrXFamdO3fK5XJp69atzraBgQHV1dWppKREhYWFWrdunWKxWMb92tvbVVtbq/z8fJWWlurhhx/W8PDw9UwFADADXXOkjh07pn/913/V0qVLM7Y/9NBDeumll9TQ0KDm5madP39ed955p7M/lUqptrZWg4ODevXVV/Xcc8/p2Wef1Y4dO679KAAAM5O5Bj09PWbRokWmsbHRfOUrXzFbtmwxxhjT3d1tvF6vaWhocMa++eabRpJpaWkxxhizb98+43a7TTQadcbs2rXLBAIBk0wmx/T48XjcSDLxePxapg8AyLKxPo9f05lUXV2damtrVV1dnbG9tbVVQ0NDGdsXL16siooKtbS0SJJaWlp08803KxQKOWNqamqUSCR0+vTpqz5eMplUIpHIuAEAZj7PeO+we/duvfbaazp27NgV+6LRqHw+n4qKijK2h0IhRaNRZ8xHAzW6f3Tf1dTX1+u73/3ueKcKAJjmxnUm1dHRoS1btujnP/+5cnNzJ2tOV9i+fbvi8bhz6+jomLLHBgBkz7gi1draqs7OTn3pS1+Sx+ORx+NRc3OznnzySXk8HoVCIQ0ODqq7uzvjfrFYTOFwWJIUDoevuNpv9OfRMR/n9/sVCAQybgCAmW9ckVqzZo1OnTqlkydPOrcVK1Zo/fr1zn97vV41NTU592lra1N7e7sikYgkKRKJ6NSpU+rs7HTGNDY2KhAIqKqqaoIOCwAwE4zrPalZs2ZpyZIlGdsKCgpUUlLibN+4caO2bdum4uJiBQIBPfjgg4pEIlq9erUkae3ataqqqtI999yjJ554QtFoVI8++qjq6urk9/sn6LAAADPBuC+c+GN++MMfyu12a926dUomk6qpqdHTTz/t7M/JydHevXu1adMmRSIRFRQUaMOGDfre97430VMBAExzLmOMyfYkxiuRSCgYDCoej/P+FABMQ2N9Hue7+wAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1vJkewLXwhgjSUokElmeCQDgWow+f48+n3+SaRmpS5cuSZLKy8uzPBMAwPXo6elRMBj8xP3TMlLFxcWSpPb29k89uM+6RCKh8vJydXR0KBAIZHs61mKdxoZ1GhvWaWyMMerp6VFZWdmnjpuWkXK7R95KCwaD/CMYg0AgwDqNAes0NqzT2LBOf9xYTjK4cAIAYC0iBQCw1rSMlN/v1+OPPy6/35/tqViNdRob1mlsWKexYZ0mlsv8sev/AADIkml5JgUA+GwgUgAAaxEpAIC1iBQAwFrTMlJPPfWUFixYoNzcXK1atUpHjx7N9pSm1KFDh/T1r39dZWVlcrlcevHFFzP2G2O0Y8cOzZs3T3l5eaqurtY777yTMaarq0vr169XIBBQUVGRNm7cqN7e3ik8islVX1+vlStXatasWSotLdUdd9yhtra2jDEDAwOqq6tTSUmJCgsLtW7dOsVisYwx7e3tqq2tVX5+vkpLS/Xwww9reHh4Kg9lUu3atUtLly51PngaiUS0f/9+Zz9rdHU7d+6Uy+XS1q1bnW2s1SQx08zu3buNz+czP/vZz8zp06fNt7/9bVNUVGRisVi2pzZl9u3b
Z/7xH//R/PKXvzSSzJ49ezL279y50wSDQfPiiy+a3/3ud+Yb3/iGqaysNP39/c6Yr371q2bZsmXm8OHD5r/+67/MwoULzd133z3FRzJ5ampqzDPPPGNef/11c/LkSfO1r33NVFRUmN7eXmfMAw88YMrLy01TU5M5fvy4Wb16tfmzP/szZ//w8LBZsmSJqa6uNidOnDD79u0zc+bMMdu3b8/GIU2K//iP/zD/+Z//ad5++23T1tZm/uEf/sF4vV7z+uuvG2NYo6s5evSoWbBggVm6dKnZsmWLs521mhzTLlK33nqrqaurc35OpVKmrKzM1NfXZ3FW2fPxSKXTaRMOh80PfvADZ1t3d7fx+/3mhRdeMMYY88YbbxhJ5tixY86Y/fv3G5fLZc6dOzdlc59KnZ2dRpJpbm42xoysidfrNQ0NDc6YN99800gyLS0txpiR/zPgdrtNNBp1xuzatcsEAgGTTCan9gCm0OzZs81Pf/pT1ugqenp6zKJFi0xjY6P5yle+4kSKtZo80+rlvsHBQbW2tqq6utrZ5na7VV1drZaWlizOzB5nzpxRNBrNWKNgMKhVq1Y5a9TS0qKioiKtWLHCGVNdXS23260jR45M+ZynQjwel/S/X07c2tqqoaGhjHVavHixKioqMtbp5ptvVigUcsbU1NQokUjo9OnTUzj7qZFKpbR792719fUpEomwRldRV1en2trajDWR+Pc0mabVF8xevHhRqVQq439kSQqFQnrrrbeyNCu7RKNRSbrqGo3ui0ajKi0tzdjv8XhUXFzsjJlJ0um0tm7dqi9/+ctasmSJpJE18Pl8Kioqyhj78XW62jqO7pspTp06pUgkooGBARUWFmrPnj2qqqrSyZMnWaOP2L17t1577TUdO3bsin38e5o80ypSwLWoq6vT66+/rt/+9rfZnoqVbrrpJp08eVLxeFz//u//rg0bNqi5uTnb07JKR0eHtmzZosbGRuXm5mZ7Op8p0+rlvjlz5ignJ+eKK2ZisZjC4XCWZmWX0XX4tDUKh8Pq7OzM2D88PKyurq4Zt46bN2/W3r17dfDgQc2fP9/ZHg6HNTg4qO7u7ozxH1+nq63j6L6ZwufzaeHChVq+fLnq6+u1bNky/fjHP2aNPqK1tVWdnZ360pe+JI/HI4/Ho+bmZj355JPyeDwKhUKs1SSZVpHy+Xxavny5mpqanG3pdFpNTU2KRCJZnJk9KisrFQ6HM9YokUjoyJEjzhpFIhF1d3ertbXVGXPgwAGl02mtWrVqyuc8GYwx2rx5s/bs2aMDBw6osrIyY//y5cvl9Xoz1qmtrU3t7e0Z63Tq1KmMoDc2NioQCKiqqmpqDiQL0um0kskka/QRa9as0alTp3Ty5EnntmLFCq1fv975b9ZqkmT7yo3x2r17t/H7/ebZZ581b7zxhrn//vtNUVFRxhUzM11PT485ceKEOXHihJFk/umf/smcOHHCvPfee8aYkUvQi4qKzK9+9Svz+9//3tx+++1XvQT9lltuMUeOHDG//e1vzaJFi2bUJeibNm0ywWDQvPLKK+bChQvO7fLly86YBx54wFRUVJgDBw6Y48ePm0gkYiKRiLN/9JLhtWvXmpMnT5pf//rXZu7cuTPqkuFHHnnENDc3mzNnzpjf//735pFHHjEul8v85je/McawRp/mo1f3GcNaTZZpFyljjPnnf/5nU1FRYXw+n7n11lvN4cOHsz2lKXXw4EEj6Yrbhg0bjDEjl6E/9thjJhQKGb/fb9asWWPa2toyfselS5fM3XffbQoLC00gEDD33nuv6enpycLRTI6rrY8k88wzzzhj+vv7zd///d+b2bNnm/z8fPPNb37TXLhwIeP3nD171tx2220mLy/PzJkzx3znO98xQ0NDU3w0k+e+++4zn/vc54zP5zNz5841a9ascQJlDGv0aT4eKdZqcvCnOgAA1ppW70kBAD5biBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYi
UgAAaxEpAIC1iBQAwFpECgBgLSIFALDW/wfizIsE7uV+0QAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "# IPython magic: draw matplotlib figures inline in the notebook output.\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "#打印游戏\n",
    "def show():\n",
    "    \"\"\"Render the current frame of the global ``env`` and display it.\"\"\"\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Display the current frame of the environment.\n",
    "show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor([[-0.4965],\n",
       "          [-0.6202]], grad_fn=<MulBackward0>),\n",
       "  tensor([[0.7176],\n",
       "          [0.7136]], grad_fn=<SoftplusBackward0>)),\n",
       " tensor([[0.0784],\n",
       "         [0.3349]], grad_fn=<AddmmBackward0>))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#定义模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc_statu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 128),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "\n",
    "        self.fc_mu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Tanh(),\n",
    "        )\n",
    "\n",
    "        self.fc_std = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Softplus(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        state = self.fc_statu(state)\n",
    "\n",
    "        mu = self.fc_mu(state) * 2.0\n",
    "        std = self.fc_std(state)\n",
    "\n",
    "        return mu, std\n",
    "\n",
    "\n",
    "# Policy network instance; get_action() in a later cell reads this global.\n",
    "model = Model()\n",
    "\n",
    "# Second network mapping a 3-dim state to one scalar.\n",
    "# NOTE(review): presumably the TD / state-value network - its use is not visible here.\n",
    "model_td = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 1),\n",
    ")\n",
    "\n",
    "# Smoke test: run both networks on a random batch of two states.\n",
    "model(torch.randn(2, 3)), model_td(torch.randn(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.33349609375"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "\n",
    "\n",
    "def get_action(state):\n",
    "    state = torch.FloatTensor(state).reshape(1, 3)\n",
    "    mu, std = model(state)\n",
    "\n",
    "    #根据概率选择一个动作\n",
    "    #action = random.normalvariate(mu=mu.item(), sigma=std.item())\n",
    "    action = torch.distributions.Normal(mu, std).sample().item()\n",
    "\n",
    "    return action\n",
    "\n",
    "\n",
    "# Smoke test with a dummy 3-element state.\n",
    "get_action([1, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1930/4114990845.py:31: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at  ../torch/csrc/utils/tensor_new.cpp:201.)\n",
      "  states = torch.FloatTensor(states).reshape(-1, 3)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(tensor([[-8.9559e-01, -4.4487e-01,  9.7730e-01],\n",
       "         [-8.7855e-01, -4.7766e-01,  7.3911e-01],\n",
       "         [-8.6885e-01, -4.9508e-01,  3.9879e-01],\n",
       "         [-8.7131e-01, -4.9073e-01, -1.0008e-01],\n",
       "         [-8.8767e-01, -4.6047e-01, -6.8792e-01],\n",
       "         [-9.1079e-01, -4.1287e-01, -1.0584e+00],\n",
       "         [-9.3704e-01, -3.4922e-01, -1.3772e+00],\n",
       "         [-9.6465e-01, -2.6352e-01, -1.8015e+00],\n",
       "         [-9.8755e-01, -1.5731e-01, -2.1741e+00],\n",
       "         [-9.9946e-01, -3.2903e-02, -2.5011e+00],\n",
       "         [-9.9455e-01,  1.0424e-01, -2.7467e+00],\n",
       "         [-9.7055e-01,  2.4088e-01, -2.7770e+00],\n",
       "         [-9.2739e-01,  3.7410e-01, -2.8030e+00],\n",
       "         [-8.6870e-01,  4.9535e-01, -2.6962e+00],\n",
       "         [-8.0322e-01,  5.9568e-01, -2.3976e+00],\n",
       "         [-7.3322e-01,  6.8000e-01, -2.1928e+00],\n",
       "         [-6.6792e-01,  7.4423e-01, -1.8325e+00],\n",
       "         [-6.1176e-01,  7.9104e-01, -1.4627e+00],\n",
       "         [-5.6650e-01,  8.2406e-01, -1.1206e+00],\n",
       "         [-5.4571e-01,  8.3798e-01, -5.0036e-01],\n",
       "         [-5.5110e-01,  8.3444e-01,  1.2911e-01],\n",
       "         [-5.8332e-01,  8.1224e-01,  7.8252e-01],\n",
       "         [-6.4206e-01,  7.6666e-01,  1.4873e+00],\n",
       "         [-7.1537e-01,  6.9874e-01,  1.9996e+00],\n",
       "         [-8.0077e-01,  5.9897e-01,  2.6285e+00],\n",
       "         [-8.8507e-01,  4.6546e-01,  3.1613e+00],\n",
       "         [-9.5649e-01,  2.9175e-01,  3.7619e+00],\n",
       "         [-9.9613e-01,  8.7854e-02,  4.1618e+00],\n",
       "         [-9.9171e-01, -1.2846e-01,  4.3358e+00],\n",
       "         [-9.3810e-01, -3.4638e-01,  4.4977e+00],\n",
       "         [-8.3854e-01, -5.4483e-01,  4.4497e+00],\n",
       "         [-7.0743e-01, -7.0679e-01,  4.1751e+00],\n",
       "         [-5.5520e-01, -8.3172e-01,  3.9450e+00],\n",
       "         [-4.0672e-01, -9.1355e-01,  3.3946e+00],\n",
       "         [-2.7016e-01, -9.6281e-01,  2.9061e+00],\n",
       "         [-1.5598e-01, -9.8776e-01,  2.3388e+00],\n",
       "         [-6.6354e-02, -9.9780e-01,  1.8044e+00],\n",
       "         [-7.8938e-03, -9.9997e-01,  1.1702e+00],\n",
       "         [ 6.0994e-03, -9.9998e-01,  2.7987e-01],\n",
       "         [-1.1161e-02, -9.9994e-01, -3.4522e-01],\n",
       "         [-6.1799e-02, -9.9809e-01, -1.0135e+00],\n",
       "         [-1.4973e-01, -9.8873e-01, -1.7691e+00],\n",
       "         [-2.8050e-01, -9.5985e-01, -2.6804e+00],\n",
       "         [-4.4377e-01, -8.9614e-01, -3.5097e+00],\n",
       "         [-6.2899e-01, -7.7741e-01, -4.4091e+00],\n",
       "         [-8.0985e-01, -5.8663e-01, -5.2729e+00],\n",
       "         [-9.4681e-01, -3.2181e-01, -5.9852e+00],\n",
       "         [-1.0000e+00, -1.3083e-03, -6.5266e+00],\n",
       "         [-9.4273e-01,  3.3355e-01, -6.8275e+00],\n",
       "         [-7.7689e-01,  6.2964e-01, -6.8204e+00],\n",
       "         [-5.2910e-01,  8.4856e-01, -6.6434e+00],\n",
       "         [-2.4489e-01,  9.6955e-01, -6.2026e+00],\n",
       "         [ 3.5032e-02,  9.9939e-01, -5.6490e+00],\n",
       "         [ 2.8666e-01,  9.5803e-01, -5.1140e+00],\n",
       "         [ 4.9723e-01,  8.6762e-01, -4.5934e+00],\n",
       "         [ 6.6876e-01,  7.4348e-01, -4.2427e+00],\n",
       "         [ 7.9957e-01,  6.0057e-01, -3.8808e+00],\n",
       "         [ 8.9607e-01,  4.4391e-01, -3.6852e+00],\n",
       "         [ 9.6038e-01,  2.7868e-01, -3.5507e+00],\n",
       "         [ 9.9423e-01,  1.0730e-01, -3.4983e+00],\n",
       "         [ 9.9746e-01, -7.1285e-02, -3.5770e+00],\n",
       "         [ 9.6647e-01, -2.5677e-01, -3.7667e+00],\n",
       "         [ 8.9095e-01, -4.5409e-01, -4.2335e+00],\n",
       "         [ 7.5634e-01, -6.5418e-01, -4.8348e+00],\n",
       "         [ 5.4905e-01, -8.3579e-01, -5.5296e+00],\n",
       "         [ 2.5922e-01, -9.6582e-01, -6.3802e+00],\n",
       "         [-1.0649e-01, -9.9431e-01, -7.3781e+00],\n",
       "         [-4.8529e-01, -8.7435e-01, -8.0000e+00],\n",
       "         [-7.8747e-01, -6.1635e-01, -8.0000e+00],\n",
       "         [-9.6533e-01, -2.6105e-01, -8.0000e+00],\n",
       "         [-9.9078e-01,  1.3548e-01, -8.0000e+00],\n",
       "         [-8.5981e-01,  5.1061e-01, -8.0000e+00],\n",
       "         [-5.9801e-01,  8.0149e-01, -7.8777e+00],\n",
       "         [-2.6018e-01,  9.6556e-01, -7.5562e+00],\n",
       "         [ 9.0400e-02,  9.9591e-01, -7.0746e+00],\n",
       "         [ 4.0810e-01,  9.1294e-01, -6.5970e+00],\n",
       "         [ 6.6761e-01,  7.4451e-01, -6.2123e+00],\n",
       "         [ 8.5419e-01,  5.1996e-01, -5.8600e+00],\n",
       "         [ 9.6683e-01,  2.5544e-01, -5.7701e+00],\n",
       "         [ 9.9937e-01, -3.5614e-02, -5.8785e+00],\n",
       "         [ 9.4174e-01, -3.3635e-01, -6.1482e+00],\n",
       "         [ 7.7879e-01, -6.2728e-01, -6.7005e+00],\n",
       "         [ 4.9711e-01, -8.6769e-01, -7.4495e+00],\n",
       "         [ 1.1997e-01, -9.9278e-01, -8.0000e+00],\n",
       "         [-2.7610e-01, -9.6113e-01, -8.0000e+00],\n",
       "         [-6.2859e-01, -7.7774e-01, -8.0000e+00],\n",
       "         [-8.8183e-01, -4.7156e-01, -8.0000e+00],\n",
       "         [-9.9586e-01, -9.0934e-02, -8.0000e+00],\n",
       "         [-9.5266e-01,  3.0405e-01, -8.0000e+00],\n",
       "         [-7.5905e-01,  6.5103e-01, -8.0000e+00],\n",
       "         [-4.5613e-01,  8.8991e-01, -7.7642e+00],\n",
       "         [-1.0581e-01,  9.9439e-01, -7.3528e+00],\n",
       "         [ 2.3360e-01,  9.7233e-01, -6.8356e+00],\n",
       "         [ 5.2787e-01,  8.4932e-01, -6.4063e+00],\n",
       "         [ 7.5756e-01,  6.5277e-01, -6.0693e+00],\n",
       "         [ 9.1333e-01,  4.0722e-01, -5.8367e+00],\n",
       "         [ 9.9118e-01,  1.3251e-01, -5.7300e+00],\n",
       "         [ 9.8691e-01, -1.6124e-01, -5.8971e+00],\n",
       "         [ 8.8975e-01, -4.5645e-01, -6.2411e+00],\n",
       "         [ 6.8512e-01, -7.2843e-01, -6.8404e+00],\n",
       "         [ 3.6296e-01, -9.3181e-01, -7.6667e+00],\n",
       "         [-2.8558e-02, -9.9959e-01, -8.0000e+00],\n",
       "         [-4.1556e-01, -9.0956e-01, -8.0000e+00],\n",
       "         [-7.3696e-01, -6.7594e-01, -8.0000e+00],\n",
       "         [-9.4201e-01, -3.3559e-01, -8.0000e+00],\n",
       "         [-9.9833e-01,  5.7733e-02, -8.0000e+00],\n",
       "         [-8.9704e-01,  4.4194e-01, -8.0000e+00],\n",
       "         [-6.5587e-01,  7.5487e-01, -7.9538e+00],\n",
       "         [-3.2664e-01,  9.4515e-01, -7.6519e+00],\n",
       "         [ 2.6614e-02,  9.9965e-01, -7.1872e+00],\n",
       "         [ 3.5394e-01,  9.3527e-01, -6.7032e+00],\n",
       "         [ 6.2459e-01,  7.8095e-01, -6.2567e+00],\n",
       "         [ 8.2480e-01,  5.6543e-01, -5.9046e+00],\n",
       "         [ 9.5174e-01,  3.0690e-01, -5.7805e+00],\n",
       "         [ 9.9981e-01,  1.9411e-02, -5.8504e+00],\n",
       "         [ 9.6029e-01, -2.7901e-01, -6.0436e+00],\n",
       "         [ 8.2035e-01, -5.7187e-01, -6.5203e+00],\n",
       "         [ 5.6428e-01, -8.2559e-01, -7.2492e+00],\n",
       "         [ 1.9823e-01, -9.8015e-01, -8.0000e+00],\n",
       "         [-1.9910e-01, -9.7998e-01, -8.0000e+00],\n",
       "         [-5.6501e-01, -8.2509e-01, -8.0000e+00],\n",
       "         [-8.4171e-01, -5.3993e-01, -8.0000e+00],\n",
       "         [-9.8553e-01, -1.6953e-01, -8.0000e+00],\n",
       "         [-9.7375e-01,  2.2763e-01, -8.0000e+00],\n",
       "         [-8.0824e-01,  5.8886e-01, -8.0000e+00],\n",
       "         [-5.2118e-01,  8.5345e-01, -7.8584e+00],\n",
       "         [-1.7221e-01,  9.8506e-01, -7.5032e+00],\n",
       "         [ 1.7917e-01,  9.8382e-01, -7.0644e+00],\n",
       "         [ 4.8946e-01,  8.7202e-01, -6.6265e+00],\n",
       "         [ 7.3330e-01,  6.7991e-01, -6.2337e+00],\n",
       "         [ 9.0176e-01,  4.3224e-01, -6.0132e+00],\n",
       "         [ 9.8914e-01,  1.4699e-01, -5.9890e+00],\n",
       "         [ 9.8739e-01, -1.5829e-01, -6.1295e+00],\n",
       "         [ 8.8508e-01, -4.6543e-01, -6.5034e+00],\n",
       "         [ 6.7015e-01, -7.4223e-01, -7.0453e+00],\n",
       "         [ 3.3409e-01, -9.4254e-01, -7.8755e+00],\n",
       "         [-5.9331e-02, -9.9824e-01, -8.0000e+00],\n",
       "         [-4.4338e-01, -8.9633e-01, -8.0000e+00],\n",
       "         [-7.5743e-01, -6.5292e-01, -8.0000e+00],\n",
       "         [-9.5190e-01, -3.0642e-01, -8.0000e+00],\n",
       "         [-9.9608e-01,  8.8454e-02, -8.0000e+00],\n",
       "         [-8.8301e-01,  4.6936e-01, -8.0000e+00],\n",
       "         [-6.3311e-01,  7.7406e-01, -7.9332e+00],\n",
       "         [-2.9845e-01,  9.5443e-01, -7.6501e+00],\n",
       "         [ 5.7442e-02,  9.9835e-01, -7.2108e+00],\n",
       "         [ 3.8403e-01,  9.2332e-01, -6.7337e+00],\n",
       "         [ 6.5144e-01,  7.5870e-01, -6.3064e+00],\n",
       "         [ 8.4755e-01,  5.3071e-01, -6.0374e+00],\n",
       "         [ 9.6543e-01,  2.6068e-01, -5.9144e+00],\n",
       "         [ 9.9942e-01, -3.4134e-02, -5.9573e+00],\n",
       "         [ 9.3996e-01, -3.4129e-01, -6.2829e+00],\n",
       "         [ 7.7110e-01, -6.3671e-01, -6.8388e+00],\n",
       "         [ 4.7921e-01, -8.7770e-01, -7.6164e+00],\n",
       "         [ 9.9587e-02, -9.9503e-01, -8.0000e+00],\n",
       "         [-2.9576e-01, -9.5526e-01, -8.0000e+00],\n",
       "         [-6.4441e-01, -7.6468e-01, -8.0000e+00],\n",
       "         [-8.9132e-01, -4.5338e-01, -8.0000e+00],\n",
       "         [-9.9751e-01, -7.0490e-02, -8.0000e+00],\n",
       "         [-9.4622e-01,  3.2352e-01, -8.0000e+00],\n",
       "         [-7.4554e-01,  6.6646e-01, -8.0000e+00],\n",
       "         [-4.3727e-01,  8.9933e-01, -7.7758e+00],\n",
       "         [-8.4253e-02,  9.9644e-01, -7.3641e+00],\n",
       "         [ 2.5852e-01,  9.6601e-01, -6.9168e+00],\n",
       "         [ 5.5263e-01,  8.3343e-01, -6.4806e+00],\n",
       "         [ 7.7762e-01,  6.2873e-01, -6.1071e+00],\n",
       "         [ 9.2749e-01,  3.7384e-01, -5.9356e+00],\n",
       "         [ 9.9632e-01,  8.5741e-02, -5.9459e+00],\n",
       "         [ 9.7519e-01, -2.2139e-01, -6.1816e+00],\n",
       "         [ 8.4957e-01, -5.2747e-01, -6.6477e+00],\n",
       "         [ 6.0631e-01, -7.9523e-01, -7.2755e+00],\n",
       "         [ 2.4877e-01, -9.6856e-01, -8.0000e+00],\n",
       "         [-1.4805e-01, -9.8898e-01, -8.0000e+00],\n",
       "         [-5.2149e-01, -8.5326e-01, -8.0000e+00],\n",
       "         [-8.1260e-01, -5.8283e-01, -8.0000e+00],\n",
       "         [-9.7541e-01, -2.2038e-01, -8.0000e+00],\n",
       "         [-9.8424e-01,  1.7686e-01, -8.0000e+00],\n",
       "         [-8.3767e-01,  5.4618e-01, -8.0000e+00],\n",
       "         [-5.6473e-01,  8.2528e-01, -7.8579e+00],\n",
       "         [-2.2131e-01,  9.7520e-01, -7.5389e+00],\n",
       "         [ 1.3090e-01,  9.9140e-01, -7.0887e+00],\n",
       "         [ 4.4619e-01,  8.9494e-01, -6.6244e+00],\n",
       "         [ 6.9983e-01,  7.1431e-01, -6.2532e+00],\n",
       "         [ 8.7965e-01,  4.7563e-01, -5.9991e+00],\n",
       "         [ 9.7990e-01,  1.9947e-01, -5.8972e+00],\n",
       "         [ 9.9485e-01, -1.0139e-01, -6.0476e+00],\n",
       "         [ 9.1281e-01, -4.0839e-01, -6.3826e+00],\n",
       "         [ 7.1843e-01, -6.9560e-01, -6.9713e+00],\n",
       "         [ 4.0149e-01, -9.1586e-01, -7.7680e+00],\n",
       "         [ 1.3141e-02, -9.9991e-01, -8.0000e+00],\n",
       "         [-3.7728e-01, -9.2610e-01, -8.0000e+00],\n",
       "         [-7.0814e-01, -7.0607e-01, -8.0000e+00],\n",
       "         [-9.2720e-01, -3.7457e-01, -8.0000e+00],\n",
       "         [-9.9987e-01,  1.6061e-02, -8.0000e+00],\n",
       "         [-9.1469e-01,  4.0416e-01, -8.0000e+00],\n",
       "         [-6.8615e-01,  7.2746e-01, -7.9711e+00],\n",
       "         [-3.6152e-01,  9.3236e-01, -7.7255e+00],\n",
       "         [-3.5900e-03,  9.9999e-01, -7.3262e+00],\n",
       "         [ 3.3181e-01,  9.4335e-01, -6.8362e+00],\n",
       "         [ 6.1011e-01,  7.9231e-01, -6.3597e+00],\n",
       "         [ 8.1857e-01,  5.7440e-01, -6.0544e+00]]),\n",
       " tensor([[ -7.2813],\n",
       "         [ -7.0433],\n",
       "         [ -6.9003],\n",
       "         [ -6.9131],\n",
       "         [ -7.1393],\n",
       "         [ -7.4886],\n",
       "         [ -7.9462],\n",
       "         [ -8.5911],\n",
       "         [ -9.3767],\n",
       "         [-10.2916],\n",
       "         [ -9.9793],\n",
       "         [ -9.1733],\n",
       "         [ -8.3945],\n",
       "         [ -7.6092],\n",
       "         [ -6.8449],\n",
       "         [ -6.2123],\n",
       "         [ -5.6375],\n",
       "         [ -5.1855],\n",
       "         [ -4.8477],\n",
       "         [ -4.6391],\n",
       "         [ -4.6435],\n",
       "         [ -4.8736],\n",
       "         [ -5.3651],\n",
       "         [ -6.0075],\n",
       "         [ -6.9381],\n",
       "         [ -8.0642],\n",
       "         [ -9.5137],\n",
       "         [-11.0572],\n",
       "         [-10.9596],\n",
       "         [ -9.7972],\n",
       "         [ -8.5621],\n",
       "         [ -7.3010],\n",
       "         [ -6.2196],\n",
       "         [ -5.1128],\n",
       "         [ -4.2473],\n",
       "         [ -3.5329],\n",
       "         [ -3.0066],\n",
       "         [ -2.6301],\n",
       "         [ -2.4568],\n",
       "         [ -2.5148],\n",
       "         [ -2.7682],\n",
       "         [ -3.2764],\n",
       "         [ -4.1604],\n",
       "         [ -5.3574],\n",
       "         [ -7.0148],\n",
       "         [ -9.1074],\n",
       "         [-11.5046],\n",
       "         [-14.1250],\n",
       "         [-12.5127],\n",
       "         [-10.7098],\n",
       "         [ -8.9449],\n",
       "         [ -7.1544],\n",
       "         [ -5.5517],\n",
       "         [ -4.2556],\n",
       "         [ -3.2172],\n",
       "         [ -2.5044],\n",
       "         [ -1.9239],\n",
       "         [ -1.5714],\n",
       "         [ -1.3416],\n",
       "         [ -1.2365],\n",
       "         [ -1.2854],\n",
       "         [ -1.4896],\n",
       "         [ -2.0174],\n",
       "         [ -2.8479],\n",
       "         [ -4.0391],\n",
       "         [ -5.7864],\n",
       "         [ -8.2608],\n",
       "         [-10.7195],\n",
       "         [-12.5410],\n",
       "         [-14.6835],\n",
       "         [-15.4376],\n",
       "         [-13.1927],\n",
       "         [-11.1014],\n",
       "         [ -9.0758],\n",
       "         [ -7.1995],\n",
       "         [ -5.6795],\n",
       "         [ -4.5665],\n",
       "         [ -3.7370],\n",
       "         [ -3.4001],\n",
       "         [ -3.4595],\n",
       "         [ -3.9018],\n",
       "         [ -4.9529],\n",
       "         [ -6.6558],\n",
       "         [ -8.5075],\n",
       "         [ -9.8279],\n",
       "         [-11.4684],\n",
       "         [-13.4293],\n",
       "         [-15.7090],\n",
       "         [-14.4279],\n",
       "         [-12.3206],\n",
       "         [-10.2110],\n",
       "         [ -8.2204],\n",
       "         [ -6.4588],\n",
       "         [ -5.1377],\n",
       "         [ -4.1925],\n",
       "         [ -3.5843],\n",
       "         [ -3.3041],\n",
       "         [ -3.5060],\n",
       "         [ -4.1227],\n",
       "         [ -5.3484],\n",
       "         [ -7.3199],\n",
       "         [ -8.9616],\n",
       "         [-10.4004],\n",
       "         [-12.1605],\n",
       "         [-14.2403],\n",
       "         [-15.9137],\n",
       "         [-13.6065],\n",
       "         [-11.5559],\n",
       "         [ -9.4813],\n",
       "         [ -7.5532],\n",
       "         [ -5.9579],\n",
       "         [ -4.7202],\n",
       "         [ -3.8516],\n",
       "         [ -3.4428],\n",
       "         [ -3.4250],\n",
       "         [ -3.7356],\n",
       "         [ -4.6260],\n",
       "         [ -6.2024],\n",
       "         [ -8.2833],\n",
       "         [ -9.5405],\n",
       "         [-11.1173],\n",
       "         [-13.0150],\n",
       "         [-15.2323],\n",
       "         [-14.8830],\n",
       "         [-12.7139],\n",
       "         [-10.6693],\n",
       "         [ -8.6749],\n",
       "         [ -6.9285],\n",
       "         [ -5.5163],\n",
       "         [ -4.4486],\n",
       "         [ -3.8197],\n",
       "         [ -3.6114],\n",
       "         [ -3.7853],\n",
       "         [ -4.4655],\n",
       "         [ -5.6664],\n",
       "         [ -7.7196],\n",
       "         [ -9.0602],\n",
       "         [-10.5256],\n",
       "         [-12.3088],\n",
       "         [-14.4138],\n",
       "         [-15.7248],\n",
       "         [-13.4421],\n",
       "         [-11.3886],\n",
       "         [ -9.3671],\n",
       "         [ -7.4930],\n",
       "         [ -5.9219],\n",
       "         [ -4.7230],\n",
       "         [ -3.9614],\n",
       "         [ -3.5700],\n",
       "         [ -3.5541],\n",
       "         [ -4.0727],\n",
       "         [ -5.1574],\n",
       "         [ -6.9513],\n",
       "         [ -8.5680],\n",
       "         [ -9.9036],\n",
       "         [-11.5616],\n",
       "         [-13.5382],\n",
       "         [-15.8353],\n",
       "         [-14.3116],\n",
       "         [-12.2218],\n",
       "         [-10.1433],\n",
       "         [ -8.1665],\n",
       "         [ -6.5021],\n",
       "         [ -5.1735],\n",
       "         [ -4.1960],\n",
       "         [ -3.6737],\n",
       "         [ -3.5468],\n",
       "         [ -3.8751],\n",
       "         [ -4.7302],\n",
       "         [ -6.1405],\n",
       "         [ -8.1448],\n",
       "         [ -9.3596],\n",
       "         [-10.8957],\n",
       "         [-12.7510],\n",
       "         [-14.9263],\n",
       "         [-15.1881],\n",
       "         [-12.9762],\n",
       "         [-10.8915],\n",
       "         [ -8.9053],\n",
       "         [ -7.1007],\n",
       "         [ -5.6206],\n",
       "         [ -4.5468],\n",
       "         [ -3.8475],\n",
       "         [ -3.5221],\n",
       "         [ -3.6707],\n",
       "         [ -4.2543],\n",
       "         [ -5.4550],\n",
       "         [ -7.3783],\n",
       "         [ -8.8303],\n",
       "         [-10.2357],\n",
       "         [-11.9625],\n",
       "         [-14.0086],\n",
       "         [-16.1719],\n",
       "         [-13.8319],\n",
       "         [-11.7727],\n",
       "         [ -9.7387],\n",
       "         [ -7.8491],\n",
       "         [ -6.1949],\n",
       "         [ -4.8848],\n",
       "         [ -4.0431]]),\n",
       " tensor([[ 0.6364],\n",
       "         [ 0.1195],\n",
       "         [-0.8504],\n",
       "         [-1.4653],\n",
       "         [-0.1676],\n",
       "         [-0.0610],\n",
       "         [-1.0823],\n",
       "         [-1.1665],\n",
       "         [-1.3935],\n",
       "         [-1.4726],\n",
       "         [-0.7234],\n",
       "         [-1.3774],\n",
       "         [-1.1588],\n",
       "         [-0.4863],\n",
       "         [-1.6130],\n",
       "         [-0.9976],\n",
       "         [-1.2557],\n",
       "         [-1.6751],\n",
       "         [ 0.0149],\n",
       "         [ 0.0066],\n",
       "         [ 0.1839],\n",
       "         [ 0.6374],\n",
       "         [-0.4182],\n",
       "         [ 0.6988],\n",
       "         [ 0.5576],\n",
       "         [ 1.6764],\n",
       "         [ 1.2073],\n",
       "         [ 0.7206],\n",
       "         [ 1.7219],\n",
       "         [ 1.4120],\n",
       "         [ 0.8937],\n",
       "         [ 2.0042],\n",
       "         [ 0.4892],\n",
       "         [ 1.3108],\n",
       "         [ 1.0320],\n",
       "         [ 1.3765],\n",
       "         [ 0.7607],\n",
       "         [-0.9356],\n",
       "         [ 0.8327],\n",
       "         [ 0.5442],\n",
       "         [-0.0463],\n",
       "         [-1.1321],\n",
       "         [-0.7294],\n",
       "         [-1.5153],\n",
       "         [-1.8715],\n",
       "         [-1.8154],\n",
       "         [-2.1027],\n",
       "         [-2.1321],\n",
       "         [-1.6204],\n",
       "         [-1.9677],\n",
       "         [-1.3044],\n",
       "         [-1.1568],\n",
       "         [-1.4304],\n",
       "         [-1.3195],\n",
       "         [-2.1412],\n",
       "         [-1.3048],\n",
       "         [-1.6992],\n",
       "         [-1.3228],\n",
       "         [-1.0440],\n",
       "         [-1.0611],\n",
       "         [-0.9082],\n",
       "         [-1.8280],\n",
       "         [-1.7383],\n",
       "         [-1.3612],\n",
       "         [-1.4914],\n",
       "         [-1.8238],\n",
       "         [-1.7823],\n",
       "         [-1.8898],\n",
       "         [-1.7481],\n",
       "         [-1.8888],\n",
       "         [-1.8397],\n",
       "         [-1.7375],\n",
       "         [-1.8643],\n",
       "         [-1.6174],\n",
       "         [-1.7954],\n",
       "         [-2.0230],\n",
       "         [-1.3739],\n",
       "         [-2.5801],\n",
       "         [-2.1480],\n",
       "         [-1.6204],\n",
       "         [-2.1517],\n",
       "         [-1.8566],\n",
       "         [-1.6404],\n",
       "         [-1.8524],\n",
       "         [-1.8409],\n",
       "         [-1.8743],\n",
       "         [-2.1716],\n",
       "         [-1.8062],\n",
       "         [-1.9938],\n",
       "         [-1.6833],\n",
       "         [-1.7069],\n",
       "         [-1.5237],\n",
       "         [-2.3831],\n",
       "         [-2.1244],\n",
       "         [-1.7126],\n",
       "         [-1.3253],\n",
       "         [-1.7760],\n",
       "         [-1.4871],\n",
       "         [-1.7130],\n",
       "         [-1.8670],\n",
       "         [-1.8810],\n",
       "         [-1.9155],\n",
       "         [-1.7264],\n",
       "         [-1.8909],\n",
       "         [-1.9840],\n",
       "         [-1.9178],\n",
       "         [-1.9019],\n",
       "         [-1.7618],\n",
       "         [-1.6274],\n",
       "         [-1.7714],\n",
       "         [-1.6995],\n",
       "         [-1.5577],\n",
       "         [-2.0448],\n",
       "         [-2.4547],\n",
       "         [-1.3851],\n",
       "         [-1.7832],\n",
       "         [-2.2450],\n",
       "         [-2.0575],\n",
       "         [-1.7253],\n",
       "         [-1.7856],\n",
       "         [-1.7513],\n",
       "         [-1.9356],\n",
       "         [-2.0134],\n",
       "         [-1.8970],\n",
       "         [-2.1281],\n",
       "         [-1.8994],\n",
       "         [-2.1977],\n",
       "         [-2.2336],\n",
       "         [-1.7414],\n",
       "         [-1.9296],\n",
       "         [-2.0716],\n",
       "         [-1.6716],\n",
       "         [-1.7012],\n",
       "         [-1.2849],\n",
       "         [-1.8237],\n",
       "         [-2.0327],\n",
       "         [-1.6577],\n",
       "         [-2.0297],\n",
       "         [-1.7567],\n",
       "         [-2.2559],\n",
       "         [-1.9532],\n",
       "         [-1.9013],\n",
       "         [-1.9829],\n",
       "         [-1.8437],\n",
       "         [-1.8113],\n",
       "         [-1.7680],\n",
       "         [-2.2713],\n",
       "         [-1.8331],\n",
       "         [-1.5898],\n",
       "         [-1.9997],\n",
       "         [-2.0072],\n",
       "         [-2.1083],\n",
       "         [-1.8012],\n",
       "         [-2.0339],\n",
       "         [-1.6674],\n",
       "         [-2.0634],\n",
       "         [-1.9289],\n",
       "         [-2.0863],\n",
       "         [-1.8618],\n",
       "         [-1.8373],\n",
       "         [-1.7522],\n",
       "         [-2.0045],\n",
       "         [-1.9226],\n",
       "         [-1.6771],\n",
       "         [-2.0241],\n",
       "         [-1.9381],\n",
       "         [-2.4922],\n",
       "         [-2.2232],\n",
       "         [-1.5480],\n",
       "         [-1.4266],\n",
       "         [-2.1161],\n",
       "         [-1.8082],\n",
       "         [-1.9619],\n",
       "         [-1.9284],\n",
       "         [-1.8701],\n",
       "         [-2.1841],\n",
       "         [-1.7835],\n",
       "         [-2.0338],\n",
       "         [-1.8746],\n",
       "         [-1.8617],\n",
       "         [-2.1245],\n",
       "         [-1.8775],\n",
       "         [-1.6989],\n",
       "         [-2.0365],\n",
       "         [-1.7264],\n",
       "         [-1.8822],\n",
       "         [-1.8333],\n",
       "         [-2.0291],\n",
       "         [-1.9952],\n",
       "         [-1.8127],\n",
       "         [-2.1953],\n",
       "         [-1.9739],\n",
       "         [-1.7304],\n",
       "         [-1.8283],\n",
       "         [-2.1060],\n",
       "         [-2.3156],\n",
       "         [-1.7328],\n",
       "         [-1.5405],\n",
       "         [-1.9257],\n",
       "         [-1.7941]]),\n",
       " tensor([[-8.7855e-01, -4.7766e-01,  7.3911e-01],\n",
       "         [-8.6885e-01, -4.9508e-01,  3.9879e-01],\n",
       "         [-8.7131e-01, -4.9073e-01, -1.0008e-01],\n",
       "         [-8.8767e-01, -4.6047e-01, -6.8792e-01],\n",
       "         [-9.1079e-01, -4.1287e-01, -1.0584e+00],\n",
       "         [-9.3704e-01, -3.4922e-01, -1.3772e+00],\n",
       "         [-9.6465e-01, -2.6352e-01, -1.8015e+00],\n",
       "         [-9.8755e-01, -1.5731e-01, -2.1741e+00],\n",
       "         [-9.9946e-01, -3.2903e-02, -2.5011e+00],\n",
       "         [-9.9455e-01,  1.0424e-01, -2.7467e+00],\n",
       "         [-9.7055e-01,  2.4088e-01, -2.7770e+00],\n",
       "         [-9.2739e-01,  3.7410e-01, -2.8030e+00],\n",
       "         [-8.6870e-01,  4.9535e-01, -2.6962e+00],\n",
       "         [-8.0322e-01,  5.9568e-01, -2.3976e+00],\n",
       "         [-7.3322e-01,  6.8000e-01, -2.1928e+00],\n",
       "         [-6.6792e-01,  7.4423e-01, -1.8325e+00],\n",
       "         [-6.1176e-01,  7.9104e-01, -1.4627e+00],\n",
       "         [-5.6650e-01,  8.2406e-01, -1.1206e+00],\n",
       "         [-5.4571e-01,  8.3798e-01, -5.0036e-01],\n",
       "         [-5.5110e-01,  8.3444e-01,  1.2911e-01],\n",
       "         [-5.8332e-01,  8.1224e-01,  7.8252e-01],\n",
       "         [-6.4206e-01,  7.6666e-01,  1.4873e+00],\n",
       "         [-7.1537e-01,  6.9874e-01,  1.9996e+00],\n",
       "         [-8.0077e-01,  5.9897e-01,  2.6285e+00],\n",
       "         [-8.8507e-01,  4.6546e-01,  3.1613e+00],\n",
       "         [-9.5649e-01,  2.9175e-01,  3.7619e+00],\n",
       "         [-9.9613e-01,  8.7854e-02,  4.1618e+00],\n",
       "         [-9.9171e-01, -1.2846e-01,  4.3358e+00],\n",
       "         [-9.3810e-01, -3.4638e-01,  4.4977e+00],\n",
       "         [-8.3854e-01, -5.4483e-01,  4.4497e+00],\n",
       "         [-7.0743e-01, -7.0679e-01,  4.1751e+00],\n",
       "         [-5.5520e-01, -8.3172e-01,  3.9450e+00],\n",
       "         [-4.0672e-01, -9.1355e-01,  3.3946e+00],\n",
       "         [-2.7016e-01, -9.6281e-01,  2.9061e+00],\n",
       "         [-1.5598e-01, -9.8776e-01,  2.3388e+00],\n",
       "         [-6.6354e-02, -9.9780e-01,  1.8044e+00],\n",
       "         [-7.8938e-03, -9.9997e-01,  1.1702e+00],\n",
       "         [ 6.0994e-03, -9.9998e-01,  2.7987e-01],\n",
       "         [-1.1161e-02, -9.9994e-01, -3.4522e-01],\n",
       "         [-6.1799e-02, -9.9809e-01, -1.0135e+00],\n",
       "         [-1.4973e-01, -9.8873e-01, -1.7691e+00],\n",
       "         [-2.8050e-01, -9.5985e-01, -2.6804e+00],\n",
       "         [-4.4377e-01, -8.9614e-01, -3.5097e+00],\n",
       "         [-6.2899e-01, -7.7741e-01, -4.4091e+00],\n",
       "         [-8.0985e-01, -5.8663e-01, -5.2729e+00],\n",
       "         [-9.4681e-01, -3.2181e-01, -5.9852e+00],\n",
       "         [-1.0000e+00, -1.3083e-03, -6.5266e+00],\n",
       "         [-9.4273e-01,  3.3355e-01, -6.8275e+00],\n",
       "         [-7.7689e-01,  6.2964e-01, -6.8204e+00],\n",
       "         [-5.2910e-01,  8.4856e-01, -6.6434e+00],\n",
       "         [-2.4489e-01,  9.6955e-01, -6.2026e+00],\n",
       "         [ 3.5032e-02,  9.9939e-01, -5.6490e+00],\n",
       "         [ 2.8666e-01,  9.5803e-01, -5.1140e+00],\n",
       "         [ 4.9723e-01,  8.6762e-01, -4.5934e+00],\n",
       "         [ 6.6876e-01,  7.4348e-01, -4.2427e+00],\n",
       "         [ 7.9957e-01,  6.0057e-01, -3.8808e+00],\n",
       "         [ 8.9607e-01,  4.4391e-01, -3.6852e+00],\n",
       "         [ 9.6038e-01,  2.7868e-01, -3.5507e+00],\n",
       "         [ 9.9423e-01,  1.0730e-01, -3.4983e+00],\n",
       "         [ 9.9746e-01, -7.1285e-02, -3.5770e+00],\n",
       "         [ 9.6647e-01, -2.5677e-01, -3.7667e+00],\n",
       "         [ 8.9095e-01, -4.5409e-01, -4.2335e+00],\n",
       "         [ 7.5634e-01, -6.5418e-01, -4.8348e+00],\n",
       "         [ 5.4905e-01, -8.3579e-01, -5.5296e+00],\n",
       "         [ 2.5922e-01, -9.6582e-01, -6.3802e+00],\n",
       "         [-1.0649e-01, -9.9431e-01, -7.3781e+00],\n",
       "         [-4.8529e-01, -8.7435e-01, -8.0000e+00],\n",
       "         [-7.8747e-01, -6.1635e-01, -8.0000e+00],\n",
       "         [-9.6533e-01, -2.6105e-01, -8.0000e+00],\n",
       "         [-9.9078e-01,  1.3548e-01, -8.0000e+00],\n",
       "         [-8.5981e-01,  5.1061e-01, -8.0000e+00],\n",
       "         [-5.9801e-01,  8.0149e-01, -7.8777e+00],\n",
       "         [-2.6018e-01,  9.6556e-01, -7.5562e+00],\n",
       "         [ 9.0400e-02,  9.9591e-01, -7.0746e+00],\n",
       "         [ 4.0810e-01,  9.1294e-01, -6.5970e+00],\n",
       "         [ 6.6761e-01,  7.4451e-01, -6.2123e+00],\n",
       "         [ 8.5419e-01,  5.1996e-01, -5.8600e+00],\n",
       "         [ 9.6683e-01,  2.5544e-01, -5.7701e+00],\n",
       "         [ 9.9937e-01, -3.5614e-02, -5.8785e+00],\n",
       "         [ 9.4174e-01, -3.3635e-01, -6.1482e+00],\n",
       "         [ 7.7879e-01, -6.2728e-01, -6.7005e+00],\n",
       "         [ 4.9711e-01, -8.6769e-01, -7.4495e+00],\n",
       "         [ 1.1997e-01, -9.9278e-01, -8.0000e+00],\n",
       "         [-2.7610e-01, -9.6113e-01, -8.0000e+00],\n",
       "         [-6.2859e-01, -7.7774e-01, -8.0000e+00],\n",
       "         [-8.8183e-01, -4.7156e-01, -8.0000e+00],\n",
       "         [-9.9586e-01, -9.0934e-02, -8.0000e+00],\n",
       "         [-9.5266e-01,  3.0405e-01, -8.0000e+00],\n",
       "         [-7.5905e-01,  6.5103e-01, -8.0000e+00],\n",
       "         [-4.5613e-01,  8.8991e-01, -7.7642e+00],\n",
       "         [-1.0581e-01,  9.9439e-01, -7.3528e+00],\n",
       "         [ 2.3360e-01,  9.7233e-01, -6.8356e+00],\n",
       "         [ 5.2787e-01,  8.4932e-01, -6.4063e+00],\n",
       "         [ 7.5756e-01,  6.5277e-01, -6.0693e+00],\n",
       "         [ 9.1333e-01,  4.0722e-01, -5.8367e+00],\n",
       "         [ 9.9118e-01,  1.3251e-01, -5.7300e+00],\n",
       "         [ 9.8691e-01, -1.6124e-01, -5.8971e+00],\n",
       "         [ 8.8975e-01, -4.5645e-01, -6.2411e+00],\n",
       "         [ 6.8512e-01, -7.2843e-01, -6.8404e+00],\n",
       "         [ 3.6296e-01, -9.3181e-01, -7.6667e+00],\n",
       "         [-2.8558e-02, -9.9959e-01, -8.0000e+00],\n",
       "         [-4.1556e-01, -9.0956e-01, -8.0000e+00],\n",
       "         [-7.3696e-01, -6.7594e-01, -8.0000e+00],\n",
       "         [-9.4201e-01, -3.3559e-01, -8.0000e+00],\n",
       "         [-9.9833e-01,  5.7733e-02, -8.0000e+00],\n",
       "         [-8.9704e-01,  4.4194e-01, -8.0000e+00],\n",
       "         [-6.5587e-01,  7.5487e-01, -7.9538e+00],\n",
       "         [-3.2664e-01,  9.4515e-01, -7.6519e+00],\n",
       "         [ 2.6614e-02,  9.9965e-01, -7.1872e+00],\n",
       "         [ 3.5394e-01,  9.3527e-01, -6.7032e+00],\n",
       "         [ 6.2459e-01,  7.8095e-01, -6.2567e+00],\n",
       "         [ 8.2480e-01,  5.6543e-01, -5.9046e+00],\n",
       "         [ 9.5174e-01,  3.0690e-01, -5.7805e+00],\n",
       "         [ 9.9981e-01,  1.9411e-02, -5.8504e+00],\n",
       "         [ 9.6029e-01, -2.7901e-01, -6.0436e+00],\n",
       "         [ 8.2035e-01, -5.7187e-01, -6.5203e+00],\n",
       "         [ 5.6428e-01, -8.2559e-01, -7.2492e+00],\n",
       "         [ 1.9823e-01, -9.8015e-01, -8.0000e+00],\n",
       "         [-1.9910e-01, -9.7998e-01, -8.0000e+00],\n",
       "         [-5.6501e-01, -8.2509e-01, -8.0000e+00],\n",
       "         [-8.4171e-01, -5.3993e-01, -8.0000e+00],\n",
       "         [-9.8553e-01, -1.6953e-01, -8.0000e+00],\n",
       "         [-9.7375e-01,  2.2763e-01, -8.0000e+00],\n",
       "         [-8.0824e-01,  5.8886e-01, -8.0000e+00],\n",
       "         [-5.2118e-01,  8.5345e-01, -7.8584e+00],\n",
       "         [-1.7221e-01,  9.8506e-01, -7.5032e+00],\n",
       "         [ 1.7917e-01,  9.8382e-01, -7.0644e+00],\n",
       "         [ 4.8946e-01,  8.7202e-01, -6.6265e+00],\n",
       "         [ 7.3330e-01,  6.7991e-01, -6.2337e+00],\n",
       "         [ 9.0176e-01,  4.3224e-01, -6.0132e+00],\n",
       "         [ 9.8914e-01,  1.4699e-01, -5.9890e+00],\n",
       "         [ 9.8739e-01, -1.5829e-01, -6.1295e+00],\n",
       "         [ 8.8508e-01, -4.6543e-01, -6.5034e+00],\n",
       "         [ 6.7015e-01, -7.4223e-01, -7.0453e+00],\n",
       "         [ 3.3409e-01, -9.4254e-01, -7.8755e+00],\n",
       "         [-5.9331e-02, -9.9824e-01, -8.0000e+00],\n",
       "         [-4.4338e-01, -8.9633e-01, -8.0000e+00],\n",
       "         [-7.5743e-01, -6.5292e-01, -8.0000e+00],\n",
       "         [-9.5190e-01, -3.0642e-01, -8.0000e+00],\n",
       "         [-9.9608e-01,  8.8454e-02, -8.0000e+00],\n",
       "         [-8.8301e-01,  4.6936e-01, -8.0000e+00],\n",
       "         [-6.3311e-01,  7.7406e-01, -7.9332e+00],\n",
       "         [-2.9845e-01,  9.5443e-01, -7.6501e+00],\n",
       "         [ 5.7442e-02,  9.9835e-01, -7.2108e+00],\n",
       "         [ 3.8403e-01,  9.2332e-01, -6.7337e+00],\n",
       "         [ 6.5144e-01,  7.5870e-01, -6.3064e+00],\n",
       "         [ 8.4755e-01,  5.3071e-01, -6.0374e+00],\n",
       "         [ 9.6543e-01,  2.6068e-01, -5.9144e+00],\n",
       "         [ 9.9942e-01, -3.4134e-02, -5.9573e+00],\n",
       "         [ 9.3996e-01, -3.4129e-01, -6.2829e+00],\n",
       "         [ 7.7110e-01, -6.3671e-01, -6.8388e+00],\n",
       "         [ 4.7921e-01, -8.7770e-01, -7.6164e+00],\n",
       "         [ 9.9587e-02, -9.9503e-01, -8.0000e+00],\n",
       "         [-2.9576e-01, -9.5526e-01, -8.0000e+00],\n",
       "         [-6.4441e-01, -7.6468e-01, -8.0000e+00],\n",
       "         [-8.9132e-01, -4.5338e-01, -8.0000e+00],\n",
       "         [-9.9751e-01, -7.0490e-02, -8.0000e+00],\n",
       "         [-9.4622e-01,  3.2352e-01, -8.0000e+00],\n",
       "         [-7.4554e-01,  6.6646e-01, -8.0000e+00],\n",
       "         [-4.3727e-01,  8.9933e-01, -7.7758e+00],\n",
       "         [-8.4253e-02,  9.9644e-01, -7.3641e+00],\n",
       "         [ 2.5852e-01,  9.6601e-01, -6.9168e+00],\n",
       "         [ 5.5263e-01,  8.3343e-01, -6.4806e+00],\n",
       "         [ 7.7762e-01,  6.2873e-01, -6.1071e+00],\n",
       "         [ 9.2749e-01,  3.7384e-01, -5.9356e+00],\n",
       "         [ 9.9632e-01,  8.5741e-02, -5.9459e+00],\n",
       "         [ 9.7519e-01, -2.2139e-01, -6.1816e+00],\n",
       "         [ 8.4957e-01, -5.2747e-01, -6.6477e+00],\n",
       "         [ 6.0631e-01, -7.9523e-01, -7.2755e+00],\n",
       "         [ 2.4877e-01, -9.6856e-01, -8.0000e+00],\n",
       "         [-1.4805e-01, -9.8898e-01, -8.0000e+00],\n",
       "         [-5.2149e-01, -8.5326e-01, -8.0000e+00],\n",
       "         [-8.1260e-01, -5.8283e-01, -8.0000e+00],\n",
       "         [-9.7541e-01, -2.2038e-01, -8.0000e+00],\n",
       "         [-9.8424e-01,  1.7686e-01, -8.0000e+00],\n",
       "         [-8.3767e-01,  5.4618e-01, -8.0000e+00],\n",
       "         [-5.6473e-01,  8.2528e-01, -7.8579e+00],\n",
       "         [-2.2131e-01,  9.7520e-01, -7.5389e+00],\n",
       "         [ 1.3090e-01,  9.9140e-01, -7.0887e+00],\n",
       "         [ 4.4619e-01,  8.9494e-01, -6.6244e+00],\n",
       "         [ 6.9983e-01,  7.1431e-01, -6.2532e+00],\n",
       "         [ 8.7965e-01,  4.7563e-01, -5.9991e+00],\n",
       "         [ 9.7990e-01,  1.9947e-01, -5.8972e+00],\n",
       "         [ 9.9485e-01, -1.0139e-01, -6.0476e+00],\n",
       "         [ 9.1281e-01, -4.0839e-01, -6.3826e+00],\n",
       "         [ 7.1843e-01, -6.9560e-01, -6.9713e+00],\n",
       "         [ 4.0149e-01, -9.1586e-01, -7.7680e+00],\n",
       "         [ 1.3141e-02, -9.9991e-01, -8.0000e+00],\n",
       "         [-3.7728e-01, -9.2610e-01, -8.0000e+00],\n",
       "         [-7.0814e-01, -7.0607e-01, -8.0000e+00],\n",
       "         [-9.2720e-01, -3.7457e-01, -8.0000e+00],\n",
       "         [-9.9987e-01,  1.6061e-02, -8.0000e+00],\n",
       "         [-9.1469e-01,  4.0416e-01, -8.0000e+00],\n",
       "         [-6.8615e-01,  7.2746e-01, -7.9711e+00],\n",
       "         [-3.6152e-01,  9.3236e-01, -7.7255e+00],\n",
       "         [-3.5900e-03,  9.9999e-01, -7.3262e+00],\n",
       "         [ 3.3181e-01,  9.4335e-01, -6.8362e+00],\n",
       "         [ 6.1011e-01,  7.9231e-01, -6.3597e+00],\n",
       "         [ 8.1857e-01,  5.7440e-01, -6.0544e+00],\n",
       "         [ 9.5010e-01,  3.1194e-01, -5.8927e+00]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1]]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_data():\n",
    "    \"\"\"Play one full episode with the current policy and return it as tensors.\n",
    "\n",
    "    Actions come from ``get_action`` and are applied through ``env.step``\n",
    "    until the environment reports the episode is over.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(states, rewards, actions, next_states, overs)`` reshaped to\n",
    "        [b, 3], [b, 1], [b, 1], [b, 3] and [b, 1] respectively, where b is\n",
    "        the number of steps taken. ``overs`` is an integer tensor, the\n",
    "        others are float32.\n",
    "    \"\"\"\n",
    "    # Local import so the function does not rely on another cell importing numpy.\n",
    "    import numpy as np\n",
    "\n",
    "    states = []\n",
    "    rewards = []\n",
    "    actions = []\n",
    "    next_states = []\n",
    "    overs = []\n",
    "\n",
    "    # Start a new episode\n",
    "    state = env.reset()\n",
    "\n",
    "    # Keep playing until the episode ends\n",
    "    over = False\n",
    "    while not over:\n",
    "        # Choose an action for the current state\n",
    "        action = get_action(state)\n",
    "\n",
    "        # Apply the action and observe the outcome\n",
    "        next_state, reward, over, _ = env.step([action])\n",
    "\n",
    "        # Record the transition\n",
    "        states.append(state)\n",
    "        rewards.append(reward)\n",
    "        actions.append(action)\n",
    "        next_states.append(next_state)\n",
    "        overs.append(over)\n",
    "\n",
    "        # Advance to the next state\n",
    "        state = next_state\n",
    "\n",
    "    # Observations are per-step ndarrays: stack them into one ndarray before\n",
    "    # handing them to torch. Building a tensor directly from a list of\n",
    "    # ndarrays goes through a slow path and makes PyTorch emit a warning.\n",
    "    # [b, 3]\n",
    "    states = torch.as_tensor(np.array(states), dtype=torch.float32).reshape(-1, 3)\n",
    "    # [b, 1]\n",
    "    rewards = torch.FloatTensor(rewards).reshape(-1, 1)\n",
    "    # [b, 1]\n",
    "    actions = torch.FloatTensor(actions).reshape(-1, 1)\n",
    "    # [b, 3]\n",
    "    next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32).reshape(-1, 3)\n",
    "    # [b, 1]\n",
    "    overs = torch.LongTensor(overs).reshape(-1, 1)\n",
    "\n",
    "    return states, rewards, actions, next_states, overs\n",
    "\n",
    "\n",
    "get_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1572.7626745772932"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "\n",
    "\n",
    "def test(play):\n",
    "    \"\"\"Run one episode with the current policy and return the summed reward.\n",
    "\n",
    "    Args:\n",
    "        play: when truthy, a random ~20% of the steps are rendered with\n",
    "            ``show()`` after clearing the previous cell output.\n",
    "\n",
    "    Returns:\n",
    "        Sum of the per-step rewards of the episode (higher is better).\n",
    "    \"\"\"\n",
    "    # Running total of the rewards collected so far\n",
    "    total_reward = 0\n",
    "\n",
    "    # Start a new episode\n",
    "    state = env.reset()\n",
    "\n",
    "    while True:\n",
    "        # Choose an action for the current state and apply it\n",
    "        state, reward, done, _ = env.step([get_action(state)])\n",
    "        total_reward += reward\n",
    "\n",
    "        # Animate, skipping most frames to keep rendering cheap\n",
    "        if play:\n",
    "            if random.random() < 0.2:\n",
    "                display.clear_output(wait=True)\n",
    "                show()\n",
    "\n",
    "        # Stop once the environment reports the episode is over\n",
    "        if done:\n",
    "            break\n",
    "\n",
    "    return total_reward\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[5.43839184, 6.7140640000000005, 7.0544, 6.24, 4.0]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#优势函数\n",
    "def get_advantages(deltas):\n",
    "    advantages = []\n",
    "\n",
    "    #反向遍历deltas\n",
    "    s = 0.0\n",
    "    for delta in deltas[::-1]:\n",
    "        s = 0.9 * 0.9 * s + delta\n",
    "        advantages.append(s)\n",
    "\n",
    "    #逆序\n",
    "    advantages.reverse()\n",
    "    return advantages\n",
    "\n",
    "\n",
    "get_advantages(range(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "executionInfo": {
     "elapsed": 8251,
     "status": "ok",
     "timestamp": 1650011468229,
     "user": {
      "displayName": "Sam Lu",
      "userId": "15789059763790170725"
     },
     "user_tz": -480
    },
    "id": "BQXVYW2T_DcQ",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -1547.517992458506\n",
      "200 -1165.2231529567657\n",
      "400 -932.6529308034745\n",
      "600 -883.3680068060828\n",
      "800 -691.3652070559533\n",
      "1000 -583.2801139359487\n",
      "1200 -331.1195745940131\n",
      "1400 -417.8007034940112\n",
      "1600 -246.89518041823234\n",
      "1800 -315.67203052079265\n",
      "2000 -199.35527144494966\n",
      "2200 -199.31985557359923\n",
      "2400 -306.49490903038406\n",
      "2600 -540.0215114327877\n",
      "2800 -439.35343252282394\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "    optimizer_td = torch.optim.Adam(model_td.parameters(), lr=5e-3)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #玩N局游戏,每局游戏训练M次\n",
    "    for epoch in range(3000):\n",
    "        #玩一局游戏,得到数据\n",
    "        #states -> [b, 3]\n",
    "        #rewards -> [b, 1]\n",
    "        #actions -> [b, 1]\n",
    "        #next_states -> [b, 3]\n",
    "        #overs -> [b, 1]\n",
    "        states, rewards, actions, next_states, overs = get_data()\n",
    "\n",
    "        #偏移reward,便于训练\n",
    "        rewards = (rewards + 8) / 8\n",
    "\n",
    "        #计算values和targets\n",
    "        #[b, 3] -> [b, 1]\n",
    "        values = model_td(states)\n",
    "\n",
    "        #[b, 3] -> [b, 1]\n",
    "        targets = model_td(next_states).detach()\n",
    "        targets = targets * 0.98\n",
    "        targets *= (1 - overs)\n",
    "        targets += rewards\n",
    "\n",
    "        #计算优势,这里的advantages有点像是策略梯度里的reward_sum\n",
    "        #只是这里计算的不是reward,而是target和value的差\n",
    "        #[b, 1]\n",
    "        deltas = (targets - values).squeeze(dim=1).tolist()\n",
    "        advantages = get_advantages(deltas)\n",
    "        advantages = torch.FloatTensor(advantages).reshape(-1, 1)\n",
    "\n",
    "        #取出每一步动作的概率\n",
    "        #[b, 3] -> [b, 1],[b, 1]\n",
    "        mu, std = model(states)\n",
    "        #[b, 1]\n",
    "        old_probs = torch.distributions.Normal(mu, std)\n",
    "        old_probs = old_probs.log_prob(actions).exp().detach()\n",
    "\n",
    "        #每批数据反复训练10次\n",
    "        for _ in range(10):\n",
    "            #重新计算每一步动作的概率\n",
    "            #[b, 3] -> [b, 1],[b, 1]\n",
    "            mu, std = model(states)\n",
    "            #[b, 1]\n",
    "            new_probs = torch.distributions.Normal(mu, std)\n",
    "            new_probs = new_probs.log_prob(actions).exp()\n",
    "\n",
    "            #求出概率的变化\n",
    "            #[b, 1] - [b, 1] -> [b, 1]\n",
    "            ratios = new_probs / old_probs\n",
    "\n",
    "            #计算截断的和不截断的两份loss,取其中小的\n",
    "            #[b, 1] * [b, 1] -> [b, 1]\n",
    "            surr1 = ratios * advantages\n",
    "            #[b, 1] * [b, 1] -> [b, 1]\n",
    "            surr2 = torch.clamp(ratios, 0.8, 1.2) * advantages\n",
    "\n",
    "            loss = -torch.min(surr1, surr2)\n",
    "            loss = loss.mean()\n",
    "\n",
    "            #重新计算value,并计算时序差分loss\n",
    "            values = model_td(states)\n",
    "            loss_td = loss_fn(values, targets)\n",
    "\n",
    "            #更新参数\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            optimizer_td.zero_grad()\n",
    "            loss_td.backward()\n",
    "            optimizer_td.step()\n",
    "\n",
    "        if epoch % 200 == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(10)]) / 10\n",
    "            print(epoch, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAh6klEQVR4nO3df3DU5aHv8c9ukt0khN0QMLumJEKPjJjywwoIW+/UTolEm1qtdK51OJRajo40OCA9TqVVHJ0zE8eeW6stYs84Fc8clQ49RSsFbRowVI0BI1F+mdpz0aTiJkjMbhLJZrP73D8se11FzZKw+2zyfs3sjPl+n908+5jsm+/uN7sOY4wRAAAWcmZ6AgAAfBoiBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwVsYitXHjRk2bNk35+flauHCh9u7dm6mpAAAslZFI/fa3v9W6det011136dVXX9XcuXNVXV2trq6uTEwHAGApRybeYHbhwoVasGCBfvWrX0mS4vG4ysvLdcstt+j2229P93QAAJbKTfc3HBwcVEtLi9avX5/Y5nQ6VVVVpaamptNeJxKJKBKJJL6Ox+Pq7u7W5MmT5XA4zvqcAQCjyxij3t5elZWVyen89Cf10h6p9957T7FYTD6fL2m7z+fTG2+8cdrr1NXV6e67707H9AAAadTR0aGpU6d+6v60R+pMrF+/XuvWrUt8HQqFVFFRoY6ODnk8ngzODABwJsLhsMrLyzVx4sTPHJf2SE2ZMkU5OTnq7OxM2t7Z2Sm/33/a67jdbrnd7k9s93g8RAoAstjnvWST9rP7XC6X5s2bp4aGhsS2eDyuhoYGBQKBdE8HAGCxjDzdt27dOq1YsULz58/XJZdcol/84hfq7+/XDTfckInpAAAslZFIXXfddTp+/Lg2bNigYDCoiy66SM8+++wnTqYAAIxvGfk7qZEKh8Pyer0KhUK8JgUAWWi4j+O8dx8AwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAa6UcqT179uiqq65SWVmZHA6HnnrqqaT9xhht2LBB5557rgoKClRVVaU333wzaUx3d7eWLVsmj8ej4uJirVy5Un19fSO6IwCAsSflSPX392vu3LnauHHjafffd999evDBB/Xwww+rublZEyZMUHV1tQYGBhJjli1bpkOHDqm+vl7bt2/Xnj17dNNNN535vQAAjE1mBCSZbdu2Jb6Ox+PG7/ebn/3sZ4ltPT09xu12myeffNIYY8zhw4eNJLNv377EmJ07dxqHw2HeeeedYX3fUChkJJlQKDSS6QMAMmS4j+Oj+prU0aNHFQwGVVVVldjm9Xq1cOFCNTU1SZKamppUXFys+fPnJ8ZUVVXJ6XSqubn5tLcbiUQUDoeTLgCAsW9UIxUMBiVJPp8vabvP50vsCwaDKi0tTdqfm5urkpKSxJiPq6urk9frTVzKy8tHc9oAAEtlxdl969evVygUSlw6OjoyPSUAQBqMaqT8fr8kqbOzM2l7Z2dnYp/f71dXV1fS/qGhIXV3dyfGfJzb7ZbH40m6AADGvlGN1PTp0+X3+9XQ0JDYFg6H1dzcrEAgIEkK
BALq6elRS0tLYsyuXbsUj8e1cOHC0ZwOACDL5aZ6hb6+Pv3tb39LfH306FG1traqpKREFRUVWrt2rf7t3/5NM2bM0PTp03XnnXeqrKxM11xzjSTpwgsv1BVXXKEbb7xRDz/8sKLRqFavXq3vfve7KisrG7U7BgAYA1I9bXD37t1G0icuK1asMMZ8eBr6nXfeaXw+n3G73Wbx4sWmra0t6TZOnDhhrr/+elNUVGQ8Ho+54YYbTG9v76ifuggAsNNwH8cdxhiTwUaekXA4LK/Xq1AoxOtTAJCFhvs4nhVn9wEAxiciBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwVspvMAtg9ER7etTf1qZIMKjYBx/IkZcn15QpKvynf1L+1KlyOByZniKQUUQKSDNjjOIDA3r/xRf1Xn29IseOKdbfLzM0JOXkKCc/X7nFxSpetEilNTXKmzyZWGHcIlJAmkXff1/v/Od/qnv3bunj7+8ciynW369Yf786f/979b72mqauXKmiykpChXGJ16SANBrs7taxxx/X+42NnwzUxxmjD/72N3X8x3/og//5n/RMELAMkQLSJB6N6sSf/6zu55+XicWGfb2TR4/q2OOPa/DEibM4O8BORApIk0gwqGP/9V8y0WjK1w23tOjErl0pxQ0YC4gUkCad//3fI7p+9549ir7//ijNBsgORApIk74jR0Z0/YG331bs5MlRmg2QHYgUkAZ9bW2jE5h4XObzTrgAxhAiBaTBiYYGDfX0jPh2YpHIyCcDZBEiBWSROE/3YZwhUkAWiQ8MZHoKQFoRKSCL8HQfxhsiBWQRjqQw3hApIIsQKYw3RArIIkQK4w2RArJIpKsr01MA0opIAWmQ/4UvyOFyjfh2Qs3NozAbIHsQKSANCqZPV05+fqanAWQdIgWkgdPlkvjQQiBlRApIA6fLJTn5dQNSxW8NkAZOt5uPfwfOAJEC0sDBkRRwRvitAdLA6XJxJAWcASIFpAGvSQFnht8aIA1G7TUpY2RisZHfDpAliBSQBg6nc9ROQTeDg6NyO0A2IFJAFjHGKM7HdWAcIVJANjGGz5TCuEKkgGxijAyRwjhCpIAsYoxRnNekMI4QKSCbGKMYnymFcYRIAdnEGM7uw7hCpIAsYuJxRbu7Mz0NIG2IFJAmBV/84ohvwwwOKtzaOvLJAFmCSAFpUvSlL2V6CkDWIVJAmvDJvEDqiBSQJk4iBaSMSAFpQqSA1BEpIE14ug9IHZEC0oQjKSB1RApIk9GMlDFm1G4LsBmRAtIkx+0enRuKxz+8AOMAkQLSZZQ+Pj4ei/HpvBg3iBSQZczQkOLRaKanAaQFkQKyjBkakhkayvQ0gLQgUkC2icWIFMYNIgVkmThHUhhHiBSQZXi6D+MJkQKyjOHpPowjRApIE4fTKWdBwYhvJ37ypGIffDAKMwLsR6SANHHm58tz8cUjvp3Iu+9qoKNjFGYE2I9IAenicMjpcmV6FkBWSSlSdXV1WrBggSZOnKjS0lJdc801amtrSxozMDCg2tpaTZ48WUVFRVq6dKk6OzuTxrS3t6umpkaFhYUqLS3VbbfdpiGeY8cY53A45Bytt0YCxomUItXY2Kja2lq9/PLLqq+vVzQa1ZIlS9Tf358Yc+utt+qZZ57R1q1b1djYqGPHjunaa69N7I/FYqqpqdHg4KBeeuklPfbYY9q8ebM2bNgwevcKsBFHUkDKHGYEb6d8/PhxlZaWqrGxUV/96lcVCoV0zjnn6IknntB3vvMdSdIbb7yhCy+8UE1NTVq0aJF27typb37zmzp27Jh8Pp8k6eGHH9aPf/xjHT9+XK5h/BKHw2F5vV6FQiF5PJ4znT6QVvFoVO9u2aLg1q0jvq3zVq/WlCVLRmFWQGYM93F8RK9JhUIhSVJJSYkkqaWlRdFoVFVVVYkxM2fOVEVFhZqamiRJTU1Nmj17diJQklRdXa1wOKxDhw6d9vtEIhGF
w+GkC5BteLoPSN0ZRyoej2vt2rW69NJLNWvWLElSMBiUy+VScXFx0lifz6dgMJgY89FAndp/at/p1NXVyev1Ji7l5eVnOm0gc4gUkLIzjlRtba0OHjyoLVu2jOZ8Tmv9+vUKhUKJSwen3yIbjfJrUnzwIcaD3DO50urVq7V9+3bt2bNHU6dOTWz3+/0aHBxUT09P0tFUZ2en/H5/YszevXuTbu/U2X+nxnyc2+2Wm3+BIts5HHKMUqTinA2LcSKlIyljjFavXq1t27Zp165dmj59etL+efPmKS8vTw0NDYltbW1tam9vVyAQkCQFAgEdOHBAXV1diTH19fXyeDyqrKwcyX0BrOZwOORwOEbltuKDg3w6L8aFlI6kamtr9cQTT+jpp5/WxIkTE68heb1eFRQUyOv1auXKlVq3bp1KSkrk8Xh0yy23KBAIaNGiRZKkJUuWqLKyUsuXL9d9992nYDCoO+64Q7W1tRwtAcNkIhGJp/swDqQUqU2bNkmSvva1ryVtf/TRR/X9739fknT//ffL6XRq6dKlikQiqq6u1kMPPZQYm5OTo+3bt2vVqlUKBAKaMGGCVqxYoXvuuWdk9wQYR+KDgzLGaHSOywB7pRSp4bxQm5+fr40bN2rjxo2fOua8887Tjh07UvnWAD4iPjjIkRTGBd67D8hC8UhEhtekMA4QKSALcSSF8YJIAVkozokTGCeIFJBGDpdLjry8Ed9OaO9exT7yxs7AWEWkgDQqKC9X/ii8rZeJRnnHCYwLRApII0durhy5Z/RGL8C4RKSANHLk5spJpIBhI1JAGnEkBaSGSAFp5MjJIVJACogUkEYcSQGpIVJAGhEpIDVECkgjIgWkhkgBaeTIyZEjJ2dUbuvUO6EDYxmRAtJotD70UJLiAwOjdluArYgUkKXikUimpwCcdUQKyFLxkyczPQXgrCNSQJbi6T6MB0QKyFIxnu7DOECkgCzF030YD4gUkKV4ug/jAZEC0iwnP18ahVPRT+zePQqzAexGpIA08yxYIGdBwSe2R2IxHe7p0b7jx/VmOPy5f6jLkRTGA96fBUgzp8uV+KPeUyF6q69P+0+cUNmECSrKzZXbyb8fAYlIAWnndLuTnu47Egrpr6GQvur3a5LbrZxRfFcKINsRKSDNPhqpjv5+tYVCqv7CF1T4jzeefae/X/u7u9Ubjeqc/HwFzjlHE/LyMjllIGOIFJBmp57ui8bjajlxQv/L51Nhbq6MMTra16e79u/XW319GojF5MnL06xJk/TvCxYoj6cAMQ7xUw+k2akjqbd6e+UvKFCJyyVJ+r99fbrxxRd1JBTSyVhMRlIoGtWLXV1a09ysEwMDihujk0NDmb0DQBoRKSDNnC6X5HAoHI3K63LJ+Y+n/n5x6JBC0ehpr7P3vfdUf+yYBmIxPffOOxqIxdI5ZSBjiBSQZk63Ww6HQ85/XFL5+I63+vq0vaND7/T3n8UZAvYgUkCaOfLyVHj++Wd03aF4XBeVlKiIEykwThApIM0cDoem/su/KG6M4sYk/laqprxceZ9yVDWtqEhzSkr0xYkTVTZhgnyn+WNgYCwiUkAG5Hm9ml5drdDgoGL/iFR1WZnu+vKXlZ+Tk/jFzHE4NNnt1v9ZsECVxcV6u69PXygslCT5li7N0OyB9OEUdCADnAUFuujyy/Vifb26IxGVFhTI4XCouqxMUwsLtf3vf9eJgQFNKyrSddOna7Lbrb5oVK3d3frf06bJfe65mhQIjOrH0QM2IlJABjgcDk2eN09XLl+uZ379ay3x+TQhL08Oh0OzJk3SrEmTksZHYjE1Hz+uuZMmqWDKFJUtX65cjydDswfSh6f7gAxxut2a98//rEuvvlp/7urS8YEBxeLxpDFxY9QXjeqFzk65cnL0palTde5116n4kkvkyMnJ0MyB9OFICsignMJCXf6v/yqPz6eG3/1OJe+9pwm5uXI4HDLGaCAW098/+EAzJk3Sgjlz5P/2t1Xyta/xNB/G
DYf5vM8DsFA4HJbX61UoFJKHpzwwBsQjEXUfPqxX6+sVfO01RYNBxSIRFXo8umDmTE0NBOSdP18FFRUcQWFMGO7jOEdSgAWcbrcmX3SRqiorFTt5UmZoSCYelyMnR868PDkLC+XM5dcV4w8/9YAlHA6HHG73h+/tB0ASJ04AACxGpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYK6VIbdq0SXPmzJHH45HH41EgENDOnTsT+wcGBlRbW6vJkyerqKhIS5cuVWdnZ9JttLe3q6amRoWFhSotLdVtt92moaGh0bk3AIAxJaVITZ06Vffee69aWlr0yiuv6Otf/7quvvpqHTp0SJJ066236plnntHWrVvV2NioY8eO6dprr01cPxaLqaamRoODg3rppZf02GOPafPmzdqwYcPo3isAwNhgRmjSpEnmkUceMT09PSYvL89s3bo1se/IkSNGkmlqajLGGLNjxw7jdDpNMBhMjNm0aZPxeDwmEokM+3uGQiEjyYRCoZFOHwCQAcN9HD/j16RisZi2bNmi/v5+BQIBtbS0KBqNqqqqKjFm5syZqqioUFNTkySpqalJs2fPls/nS4yprq5WOBxOHI2dTiQSUTgcTroAAMa+lCN14MABFRUVye126+abb9a2bdtUWVmpYDAol8ul4uLipPE+n0/BYFCSFAwGkwJ1av+pfZ+mrq5OXq83cSkvL0912gCALJRypC644AK1traqublZq1at0ooVK3T48OGzMbeE9evXKxQKJS4dHR1n9fsBAOyQm+oVXC6Xzj//fEnSvHnztG/fPj3wwAO67rrrNDg4qJ6enqSjqc7OTvn9fkmS3+/X3r17k27v1Nl/p8acjtvtltvtTnWqAIAsN+K/k4rH44pEIpo3b57y8vLU0NCQ2NfW1qb29nYFAgFJUiAQ0IEDB9TV1ZUYU19fL4/Ho8rKypFOBQAwxqR0JLV+/XpdeeWVqqioUG9vr5544gk9//zzeu655+T1erVy5UqtW7dOJSUl8ng8uuWWWxQIBLRo0SJJ0pIlS1RZWanly5frvvvuUzAY1B133KHa2lqOlAAAn5BSpLq6uvS9731P7777rrxer+bMmaPnnntOl19+uSTp/vvvl9Pp1NKlSxWJRFRdXa2HHnoocf2cnBxt375dq1atUiAQ0IQJE7RixQrdc889o3uvAABjgsMYYzI9iVSFw2F5vV6FQiF5PJ5MTwcAkKLhPo7z3n0AAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArDWiSN17771yOBxau3ZtYtvAwIBqa2s1efJkFRUVaenSpers7Ey6Xnt7u2pqalRYWKjS0lLddtttGhoaGslUAABj0BlHat++ffr1r3+tOXPmJG2/9dZb9cwzz2jr1q1qbGzUsWPHdO211yb2x2Ix1dTUaHBwUC+99JIee+wxbd68WRs2bDjzewEAGJvMGejt7TUzZsww9fX15rLLLjNr1qwxxhjT09Nj8vLyzNatWxNjjxw5YiSZ
pqYmY4wxO3bsME6n0wSDwcSYTZs2GY/HYyKRyLC+fygUMpJMKBQ6k+kDADJsuI/jZ3QkVVtbq5qaGlVVVSVtb2lpUTQaTdo+c+ZMVVRUqKmpSZLU1NSk2bNny+fzJcZUV1crHA7r0KFDp/1+kUhE4XA46QIAGPtyU73Cli1b9Oqrr2rfvn2f2BcMBuVyuVRcXJy03efzKRgMJsZ8NFCn9p/adzp1dXW6++67U50qACDLpXQk1dHRoTVr1ujxxx9Xfn7+2ZrTJ6xfv16hUChx6ejoSNv3BgBkTkqRamlpUVdXly6++GLl5uYqNzdXjY2NevDBB5Wbmyufz6fBwUH19PQkXa+zs1N+v1+S5Pf7P3G236mvT435OLfbLY/Hk3QBAIx9KUVq8eLFOnDggFpbWxOX+fPna9myZYn/zsvLU0NDQ+I6bW1tam9vVyAQkCQFAgEdOHBAXV1diTH19fXyeDyqrKwcpbsFABgLUnpNauLEiZo1a1bStgkTJmjy5MmJ7StXrtS6detUUlIij8ejW265RYFAQIsWLZIkLVmyRJWVlVq+fLnuu+8+BYNB3XHHHaqtrZXb7R6luwUAGAtSPnHi89x///1yOp1aunSpIpGIqqur9dBDDyX25+TkaPv27Vq1apUCgYAmTJigFStW6J577hntqQAAspzDGGMyPYlUhcNheb1ehUIhXp8CgCw03Mdx3rsPAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGCt3ExP4EwYYyRJ4XA4wzMBAJyJU4/fpx7PP01WRurEiROSpPLy8gzPBAAwEr29vfJ6vZ+6PysjVVJSIklqb2//zDs33oXDYZWXl6ujo0MejyfT07EW6zQ8rNPwsE7DY4xRb2+vysrKPnNcVkbK6fzwpTSv18sPwTB4PB7WaRhYp+FhnYaHdfp8wznI4MQJAIC1iBQAwFpZGSm326277rpLbrc701OxGus0PKzT8LBOw8M6jS6H+bzz/wAAyJCsPJICAIwPRAoAYC0iBQCwFpECAFgrKyO1ceNGTZs2Tfn5+Vq4cKH27t2b6Sml1Z49e3TVVVeprKxMDodDTz31VNJ+Y4w2bNigc889VwUFBaqqqtKbb76ZNKa7u1vLli2Tx+NRcXGxVq5cqb6+vjTei7Orrq5OCxYs0MSJE1VaWqprrrlGbW1tSWMGBgZUW1uryZMnq6ioSEuXLlVnZ2fSmPb2dtXU1KiwsFClpaW67bbbNDQ0lM67clZt2rRJc+bMSfzhaSAQ0M6dOxP7WaPTu/fee+VwOLR27drENtbqLDFZZsuWLcblcpnf/OY35tChQ+bGG280xcXFprOzM9NTS5sdO3aYn/70p+b3v/+9kWS2bduWtP/ee+81Xq/XPPXUU+a1114z3/rWt8z06dPNyZMnE2OuuOIKM3fuXPPyyy+bv/zlL+b88883119/fZrvydlTXV1tHn30UXPw4EHT2tpqvvGNb5iKigrT19eXGHPzzTeb8vJy09DQYF555RWzaNEi85WvfCWxf2hoyMyaNctUVVWZ/fv3mx07dpgpU6aY9evXZ+IunRV/+MMfzB//+Efz17/+1bS1tZmf/OQnJi8vzxw8eNAYwxqdzt69
e820adPMnDlzzJo1axLbWauzI+sidckll5ja2trE17FYzJSVlZm6uroMzipzPh6peDxu/H6/+dnPfpbY1tPTY9xut3nyySeNMcYcPnzYSDL79u1LjNm5c6dxOBzmnXfeSdvc06mrq8tIMo2NjcaYD9ckLy/PbN26NTHmyJEjRpJpamoyxnz4jwGn02mCwWBizKZNm4zH4zGRSCS9dyCNJk2aZB555BHW6DR6e3vNjBkzTH19vbnssssSkWKtzp6serpvcHBQLS0tqqqqSmxzOp2qqqpSU1NTBmdmj6NHjyoYDCatkdfr1cKFCxNr1NTUpOLiYs2fPz8xpqqqSk6nU83NzWmfczqEQiFJ///NiVtaWhSNRpPWaebMmaqoqEhap9mzZ8vn8yXGVFdXKxwO69ChQ2mcfXrEYjFt2bJF/f39CgQCrNFp1NbWqqamJmlNJH6ezqaseoPZ9957T7FYLOl/siT5fD698cYbGZqVXYLBoCSddo1O7QsGgyotLU3an5ubq5KSksSYsSQej2vt2rW69NJLNWvWLEkfroHL5VJxcXHS2I+v0+nW8dS+seLAgQMKBAIaGBhQUVGRtm3bpsrKSrW2trJGH7Flyxa9+uqr2rdv3yf28fN09mRVpIAzUVtbq4MHD+qFF17I9FSsdMEFF6i1tVWhUEi/+93vtGLFCjU2NmZ6Wlbp6OjQmjVrVF9fr/z8/ExPZ1zJqqf7pkyZopycnE+cMdPZ2Sm/35+hWdnl1Dp81hr5/X51dXUl7R8aGlJ3d/eYW8fVq1dr+/bt2r17t6ZOnZrY7vf7NTg4qJ6enqTxH1+n063jqX1jhcvl0vnnn6958+aprq5Oc+fO1QMPPMAafURLS4u6urp08cUXKzc3V7m5uWpsbNSDDz6o3Nxc+Xw+1uosyapIuVwuzZs3Tw0NDYlt8XhcDQ0NCgQCGZyZPaZPny6/35+0RuFwWM3NzYk1CgQC6unpUUtLS2LMrl27FI/HtXDhwrTP+Wwwxmj16tXatm2bdu3apenTpyftnzdvnvLy8pLWqa2tTe3t7UnrdODAgaSg19fXy+PxqLKyMj13JAPi8bgikQhr9BGLFy/WgQMH1NramrjMnz9fy5YtS/w3a3WWZPrMjVRt2bLFuN1us3nzZnP48GFz0003meLi4qQzZsa63t5es3//frN//34jyfz85z83+/fvN2+//bYx5sNT0IuLi83TTz9tXn/9dXP11Vef9hT0L3/5y6a5udm88MILZsaMGWPqFPRVq1YZr9drnn/+efPuu+8mLh988EFizM0332wqKirMrl27zCuvvGICgYAJBAKJ/adOGV6yZIlpbW01zz77rDnnnHPG1CnDt99+u2lsbDRHjx41r7/+urn99tuNw+Ewf/rTn4wxrNFn+ejZfcawVmdL1kXKGGN++ctfmoqKCuNyucwll1xiXn755UxPKa12795tJH3ismLFCmPMh6eh33nnncbn8xm3220WL15s2trakm7jxIkT5vrrrzdFRUXG4/GYG264wfT29mbg3pwdp1sfSebRRx9NjDl58qT54Q9/aCZNmmQKCwvNt7/9bfPuu+8m3c5bb71lrrzySlNQUGCmTJlifvSjH5loNJrme3P2/OAHPzDnnXeecblc5pxzzjGLFy9OBMoY1uizfDxSrNXZwUd1AACslVWvSQEAxhciBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArPX/AOT+iH4f/XcYAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-521.8353094207628"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第9章-策略梯度算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
